mirror of https://github.com/microsoft/autogen.git
Add Session Saving to AGS (#4369)
* fix import issue related to agentchat update #4245
* update uv lock file
* fix db auto_upgrade logic issue
* improve msg rendering issue
* Support termination condition combination. Closes #4325
* fix db instantiation bug
* update yarn.lock, closes #4260 #4262
* remove deps for now with vulnerabilities found by dependabot #4262
* update db tests
* add ability to load sessions from db
* format updates, add format checks to ags
* format check fixes
* linting and ruff check fixes
* make tests for ags non-parallel to avoid db race conditions
* format updates
* fix concurrency issue
* minor ui tweaks, move run start to websocket
* lint fixes
* update uv.lock
* Update python/packages/autogen-studio/autogenstudio/datamodel/types.py (Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>)
* Update python/packages/autogen-studio/autogenstudio/teammanager.py (Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>)
* reuse user proxy from agentchat
* ui tweaks

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Co-authored-by: Hussein Mozannar <hmozannar@microsoft.com>
This commit is contained in:
parent
df183be35a
commit
fe96f7de24
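For context, the headline changes in this commit are loading sessions from the database and support for combining termination conditions (#4325). In the new ComponentFactory.load_termination shown further down, a combination reduces to composing autogen_agentchat termination conditions with & ("and") and | ("or"). A minimal sketch of that idea follows; the config dict is only an assumed shape mirroring the TerminationConfig fields that appear in this diff, and the exact enum values live in autogenstudio.datamodel.types:

from autogen_agentchat.task import MaxMessageTermination, TextMentionTermination

# Stop when either condition fires -- the "or" operator in a combination config.
termination = MaxMessageTermination(max_messages=10) | TextMentionTermination(text="TERMINATE")

# Assumed shape of the equivalent AGS termination config consumed by load_termination;
# field names follow TerminationConfig in this diff, string values are illustrative.
termination_config = {
    "component_type": "termination",
    "termination_type": "combination",
    "operator": "or",
    "conditions": [
        {"component_type": "termination", "termination_type": "max_messages", "max_messages": 10},
        {"component_type": "termination", "termination_type": "text_mention", "text": "TERMINATE"},
    ],
}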
@@ -1,3 +1,18 @@
from .datamodel import *
from .database.db_manager import DatabaseManager
from .datamodel import Agent, AgentConfig, Model, ModelConfig, Team, TeamConfig, Tool, ToolConfig
from .teammanager import TeamManager
from .version import __version__
from .teammanager import *

__all__ = [
    "Tool",
    "Model",
    "DatabaseManager",
    "Team",
    "Agent",
    "ToolConfig",
    "ModelConfig",
    "TeamConfig",
    "AgentConfig",
    "TeamManager",
    "__version__",
]
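The explicit export list above replaces the old star imports, so callers can import the public surface by name. A minimal usage sketch (the SQLite URI is just an illustrative value; DatabaseManager's new signature appears later in this commit):

from autogenstudio import DatabaseManager, TeamManager  # both re-exported via __all__ above

# Illustrative only: point the database manager at a local SQLite file.
db = DatabaseManager(engine_uri="sqlite:///autogen.db")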
@@ -15,7 +15,7 @@ def ui(
    host: str = "127.0.0.1",
    port: int = 8081,
    workers: int = 1,
    reload: Annotated[bool, typer.Option("--reload")] = True,
    reload: Annotated[bool, typer.Option("--reload")] = False,
    docs: bool = True,
    appdir: str = None,
    database_uri: Optional[str] = None,

@@ -48,11 +48,7 @@ def ui(
        port=port,
        workers=workers,
        reload=reload,
        reload_excludes=[
            "**/alembic/*",
            "**/alembic.ini",
            "**/versions/*"
        ] if reload else None
        reload_excludes=["**/alembic/*", "**/alembic.ini", "**/versions/*"] if reload else None,
    )
@@ -1 +0,0 @@
from .agents.userproxy import UserProxyAgent
@@ -1,47 +0,0 @@
from typing import Callable, List, Optional, Sequence, Union, Awaitable
from inspect import iscoroutinefunction

from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import ChatMessage, TextMessage
from autogen_core.base import CancellationToken
import asyncio


class UserProxyAgent(BaseChatAgent):
    """An agent that can represent a human user in a chat."""

    def __init__(
        self,
        name: str,
        description: Optional[str] = "a",
        input_func: Optional[Union[Callable[..., str],
                                   Callable[..., Awaitable[str]]]] = None
    ) -> None:
        super().__init__(name, description=description)
        self.input_func = input_func or input
        self._is_async = iscoroutinefunction(
            input_func) if input_func else False

    @property
    def produced_message_types(self) -> List[type[ChatMessage]]:
        return [TextMessage]

    async def _get_input(self, prompt: str) -> str:
        """Handle both sync and async input functions"""
        if self._is_async:
            return await self.input_func(prompt)
        else:
            return await asyncio.get_event_loop().run_in_executor(None, self.input_func, prompt)

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:

        try:
            user_input = await self._get_input("Enter your response: ")
            return Response(chat_message=TextMessage(content=user_input, source=self.name))
        except Exception as e:
            # Consider logging the error here
            raise RuntimeError(f"Failed to get user input: {str(e)}") from e

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        pass
@@ -1,3 +1,3 @@
from .db_manager import DatabaseManager
from .component_factory import ComponentFactory, Component
from .component_factory import Component, ComponentFactory
from .config_manager import ConfigurationManager
from .db_manager import DatabaseManager
@@ -1,43 +1,45 @@
import os
from pathlib import Path
from typing import Callable, List, Literal, Union, Optional, Dict, Any, Type
from datetime import datetime
import json
from autogen_agentchat.task import MaxMessageTermination, TextMentionTermination, StopMessageTermination
import yaml
import logging
from packaging import version
from datetime import datetime
from pathlib import Path
from typing import Callable, Dict, List, Literal, Optional, Union

from ..datamodel import (
    TeamConfig, AgentConfig, ModelConfig, ToolConfig,
    TeamTypes, AgentTypes, ModelTypes, ToolTypes,
    ComponentType, ComponentConfig, ComponentConfigInput, TerminationConfig, TerminationTypes, Response
)
from ..components import UserProxyAgent
from autogen_agentchat.agents import AssistantAgent
import aiofiles
import yaml
from autogen_agentchat.agents import AssistantAgent, UserProxyAgent
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination, TextMentionTermination
from autogen_agentchat.teams import RoundRobinGroupChat, SelectorGroupChat
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_core.components.tools import FunctionTool
from autogen_ext.models import OpenAIChatCompletionClient

from ..datamodel.types import (
    AgentConfig,
    AgentTypes,
    ComponentConfig,
    ComponentConfigInput,
    ComponentTypes,
    ModelConfig,
    ModelTypes,
    TeamConfig,
    TeamTypes,
    TerminationConfig,
    TerminationTypes,
    ToolConfig,
    ToolTypes,
)
from ..utils.utils import Version

logger = logging.getLogger(__name__)

# Type definitions for supported components
TeamComponent = Union[RoundRobinGroupChat, SelectorGroupChat]
AgentComponent = Union[AssistantAgent]  # Will grow with more agent types
# Will grow with more model types
AgentComponent = Union[AssistantAgent]
ModelComponent = Union[OpenAIChatCompletionClient]
ToolComponent = Union[FunctionTool]  # Will grow with more tool types
TerminationComponent = Union[MaxMessageTermination,
                             StopMessageTermination, TextMentionTermination]
TerminationComponent = Union[MaxMessageTermination, StopMessageTermination, TextMentionTermination]

# Config type definitions
Component = Union[TeamComponent, AgentComponent, ModelComponent, ToolComponent, TerminationComponent]

Component = Union[TeamComponent, AgentComponent, ModelComponent, ToolComponent]


ReturnType = Literal['object', 'dict', 'config']
Component = Union[RoundRobinGroupChat, SelectorGroupChat,
                  AssistantAgent, OpenAIChatCompletionClient, FunctionTool]
ReturnType = Literal["object", "dict", "config"]

DEFAULT_SELECTOR_PROMPT = """You are in a role play game. The following roles are available:
{roles}.
|
@ -48,18 +50,18 @@ Read the following conversation. Then select the next role from {participants} t
|
|||
Read the above conversation. Then select the next role from {participants} to play. Only return the role.
|
||||
"""
|
||||
|
||||
CONFIG_RETURN_TYPES = Literal['object', 'dict', 'config']
|
||||
CONFIG_RETURN_TYPES = Literal["object", "dict", "config"]
|
||||
|
||||
|
||||
class ComponentFactory:
|
||||
"""Creates and manages agent components with versioned configuration loading"""
|
||||
|
||||
SUPPORTED_VERSIONS = {
|
||||
ComponentType.TEAM: ["1.0.0"],
|
||||
ComponentType.AGENT: ["1.0.0"],
|
||||
ComponentType.MODEL: ["1.0.0"],
|
||||
ComponentType.TOOL: ["1.0.0"],
|
||||
ComponentType.TERMINATION: ["1.0.0"]
|
||||
ComponentTypes.TEAM: ["1.0.0"],
|
||||
ComponentTypes.AGENT: ["1.0.0"],
|
||||
ComponentTypes.MODEL: ["1.0.0"],
|
||||
ComponentTypes.TOOL: ["1.0.0"],
|
||||
ComponentTypes.TERMINATION: ["1.0.0"],
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
|
@ -68,10 +70,7 @@ class ComponentFactory:
|
|||
self._last_cache_clear = datetime.now()
|
||||
|
||||
async def load(
|
||||
self,
|
||||
component: ComponentConfigInput,
|
||||
input_func: Optional[Callable] = None,
|
||||
return_type: ReturnType = 'object'
|
||||
self, component: ComponentConfigInput, input_func: Optional[Callable] = None, return_type: ReturnType = "object"
|
||||
) -> Union[Component, dict, ComponentConfig]:
|
||||
"""
|
||||
Universal loader for any component type
|
||||
|
@ -103,24 +102,23 @@ class ComponentFactory:
|
|||
)
|
||||
|
||||
# Return early if dict or config requested
|
||||
if return_type == 'dict':
|
||||
if return_type == "dict":
|
||||
return config.model_dump()
|
||||
elif return_type == 'config':
|
||||
elif return_type == "config":
|
||||
return config
|
||||
|
||||
# Otherwise create and return component instance
|
||||
handlers = {
|
||||
ComponentType.TEAM: lambda c: self.load_team(c, input_func),
|
||||
ComponentType.AGENT: lambda c: self.load_agent(c, input_func),
|
||||
ComponentType.MODEL: self.load_model,
|
||||
ComponentType.TOOL: self.load_tool,
|
||||
ComponentType.TERMINATION: self.load_termination
|
||||
ComponentTypes.TEAM: lambda c: self.load_team(c, input_func),
|
||||
ComponentTypes.AGENT: lambda c: self.load_agent(c, input_func),
|
||||
ComponentTypes.MODEL: self.load_model,
|
||||
ComponentTypes.TOOL: self.load_tool,
|
||||
ComponentTypes.TERMINATION: self.load_termination,
|
||||
}
|
||||
|
||||
handler = handlers.get(config.component_type)
|
||||
if not handler:
|
||||
raise ValueError(
|
||||
f"Unknown component type: {config.component_type}")
|
||||
raise ValueError(f"Unknown component type: {config.component_type}")
|
||||
|
||||
return await handler(config)
|
||||
|
||||
|
@ -128,7 +126,9 @@ class ComponentFactory:
|
|||
logger.error(f"Failed to load component: {str(e)}")
|
||||
raise
|
||||
|
||||
async def load_directory(self, directory: Union[str, Path], return_type: ReturnType = 'object') -> List[Union[Component, dict, ComponentConfig]]:
|
||||
async def load_directory(
|
||||
self, directory: Union[str, Path], return_type: ReturnType = "object"
|
||||
) -> List[Union[Component, dict, ComponentConfig]]:
|
||||
"""
|
||||
Import all component configurations from a directory.
|
||||
"""
|
||||
|
@ -137,13 +137,12 @@ class ComponentFactory:
|
|||
directory = Path(directory)
|
||||
# Using Path.iterdir() instead of os.listdir
|
||||
for path in list(directory.glob("*")):
|
||||
if path.suffix.lower().endswith(('.json', '.yaml', '.yml')):
|
||||
if path.suffix.lower().endswith((".json", ".yaml", ".yml")):
|
||||
try:
|
||||
component = await self.load(path, return_type=return_type)
|
||||
components.append(component)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
f"Failed to load component: {str(e)}, {path}")
|
||||
logger.info(f"Failed to load component: {str(e)}, {path}")
|
||||
|
||||
return components
|
||||
except Exception as e:
|
||||
|
@ -156,14 +155,14 @@ class ComponentFactory:
|
|||
raise ValueError("component_type is required in configuration")
|
||||
|
||||
config_types = {
|
||||
ComponentType.TEAM: TeamConfig,
|
||||
ComponentType.AGENT: AgentConfig,
|
||||
ComponentType.MODEL: ModelConfig,
|
||||
ComponentType.TOOL: ToolConfig,
|
||||
ComponentType.TERMINATION: TerminationConfig # Add mapping for termination
|
||||
ComponentTypes.TEAM: TeamConfig,
|
||||
ComponentTypes.AGENT: AgentConfig,
|
||||
ComponentTypes.MODEL: ModelConfig,
|
||||
ComponentTypes.TOOL: ToolConfig,
|
||||
ComponentTypes.TERMINATION: TerminationConfig, # Add mapping for termination
|
||||
}
|
||||
|
||||
component_type = ComponentType(config_dict["component_type"])
|
||||
component_type = ComponentTypes(config_dict["component_type"])
|
||||
config_class = config_types.get(component_type)
|
||||
|
||||
if not config_class:
|
||||
|
@@ -174,28 +173,44 @@
    async def load_termination(self, config: TerminationConfig) -> TerminationComponent:
        """Create termination condition instance from configuration."""
        try:
            if config.termination_type == TerminationTypes.MAX_MESSAGES:
            if config.termination_type == TerminationTypes.COMBINATION:
                if not config.conditions or len(config.conditions) < 2:
                    raise ValueError("Combination termination requires at least 2 conditions")
                if not config.operator:
                    raise ValueError("Combination termination requires an operator (and/or)")

                # Load first two conditions
                conditions = [await self.load_termination(cond) for cond in config.conditions[:2]]
                result = conditions[0] & conditions[1] if config.operator == "and" else conditions[0] | conditions[1]

                # Process remaining conditions if any
                for condition in config.conditions[2:]:
                    next_condition = await self.load_termination(condition)
                    result = result & next_condition if config.operator == "and" else result | next_condition

                return result

            elif config.termination_type == TerminationTypes.MAX_MESSAGES:
                if config.max_messages is None:
                    raise ValueError("max_messages parameter required for MaxMessageTermination")
                return MaxMessageTermination(max_messages=config.max_messages)

            elif config.termination_type == TerminationTypes.STOP_MESSAGE:
                return StopMessageTermination()

            elif config.termination_type == TerminationTypes.TEXT_MENTION:
                if not config.text:
                    raise ValueError(
                        "text parameter required for TextMentionTermination")
                    raise ValueError("text parameter required for TextMentionTermination")
                return TextMentionTermination(text=config.text)

            else:
                raise ValueError(
                    f"Unsupported termination type: {config.termination_type}")
                raise ValueError(f"Unsupported termination type: {config.termination_type}")

        except Exception as e:
            logger.error(f"Failed to create termination condition: {str(e)}")
            raise ValueError(
                f"Termination condition creation failed: {str(e)}")
            raise ValueError(f"Termination condition creation failed: {str(e)}") from e
|
||||
async def load_team(
|
||||
self,
|
||||
config: TeamConfig,
|
||||
input_func: Optional[Callable] = None
|
||||
) -> TeamComponent:
|
||||
async def load_team(self, config: TeamConfig, input_func: Optional[Callable] = None) -> TeamComponent:
|
||||
"""Create team instance from configuration."""
|
||||
try:
|
||||
# Load participants (agents) with input_func
|
||||
|
@ -216,33 +231,25 @@ class ComponentFactory:
|
|||
|
||||
# Create team based on type
|
||||
if config.team_type == TeamTypes.ROUND_ROBIN:
|
||||
return RoundRobinGroupChat(
|
||||
participants=participants,
|
||||
termination_condition=termination
|
||||
)
|
||||
return RoundRobinGroupChat(participants=participants, termination_condition=termination)
|
||||
elif config.team_type == TeamTypes.SELECTOR:
|
||||
if not model_client:
|
||||
raise ValueError(
|
||||
"SelectorGroupChat requires a model_client")
|
||||
raise ValueError("SelectorGroupChat requires a model_client")
|
||||
selector_prompt = config.selector_prompt if config.selector_prompt else DEFAULT_SELECTOR_PROMPT
|
||||
return SelectorGroupChat(
|
||||
participants=participants,
|
||||
model_client=model_client,
|
||||
termination_condition=termination,
|
||||
selector_prompt=selector_prompt
|
||||
selector_prompt=selector_prompt,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported team type: {config.team_type}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create team {config.name}: {str(e)}")
|
||||
raise ValueError(f"Team creation failed: {str(e)}")
|
||||
raise ValueError(f"Team creation failed: {str(e)}") from e
|
||||
|
||||
async def load_agent(
|
||||
self,
|
||||
config: AgentConfig,
|
||||
input_func: Optional[Callable] = None
|
||||
) -> AgentComponent:
|
||||
async def load_agent(self, config: AgentConfig, input_func: Optional[Callable] = None) -> AgentComponent:
|
||||
"""Create agent instance from configuration."""
|
||||
try:
|
||||
# Load model client if specified
|
||||
|
@ -263,7 +270,7 @@ class ComponentFactory:
|
|||
return UserProxyAgent(
|
||||
name=config.name,
|
||||
description=config.description or "A human user",
|
||||
input_func=input_func # Pass through to UserProxyAgent
|
||||
input_func=input_func, # Pass through to UserProxyAgent
|
||||
)
|
||||
elif config.agent_type == AgentTypes.ASSISTANT:
|
||||
return AssistantAgent(
|
||||
|
@ -271,15 +278,14 @@ class ComponentFactory:
|
|||
description=config.description or "A helpful assistant",
|
||||
model_client=model_client,
|
||||
tools=tools,
|
||||
system_message=system_message
|
||||
system_message=system_message,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported agent type: {config.agent_type}")
|
||||
raise ValueError(f"Unsupported agent type: {config.agent_type}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create agent {config.name}: {str(e)}")
|
||||
raise ValueError(f"Agent creation failed: {str(e)}")
|
||||
raise ValueError(f"Agent creation failed: {str(e)}") from e
|
||||
|
||||
async def load_model(self, config: ModelConfig) -> ModelComponent:
|
||||
"""Create model instance from configuration."""
|
||||
|
@ -291,20 +297,15 @@ class ComponentFactory:
|
|||
return self._model_cache[cache_key]
|
||||
|
||||
if config.model_type == ModelTypes.OPENAI:
|
||||
model = OpenAIChatCompletionClient(
|
||||
model=config.model,
|
||||
api_key=config.api_key,
|
||||
base_url=config.base_url
|
||||
)
|
||||
model = OpenAIChatCompletionClient(model=config.model, api_key=config.api_key, base_url=config.base_url)
|
||||
self._model_cache[cache_key] = model
|
||||
return model
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported model type: {config.model_type}")
|
||||
raise ValueError(f"Unsupported model type: {config.model_type}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create model {config.model}: {str(e)}")
|
||||
raise ValueError(f"Model creation failed: {str(e)}")
|
||||
raise ValueError(f"Model creation failed: {str(e)}") from e
|
||||
|
||||
async def load_tool(self, config: ToolConfig) -> ToolComponent:
|
||||
"""Create tool instance from configuration."""
|
||||
|
@ -321,9 +322,7 @@ class ComponentFactory:
|
|||
|
||||
if config.tool_type == ToolTypes.PYTHON_FUNCTION:
|
||||
tool = FunctionTool(
|
||||
name=config.name,
|
||||
description=config.description,
|
||||
func=self._func_from_string(config.content)
|
||||
name=config.name, description=config.description, func=self._func_from_string(config.content)
|
||||
)
|
||||
self._tool_cache[cache_key] = tool
|
||||
return tool
|
||||
|
@ -334,7 +333,6 @@ class ComponentFactory:
|
|||
logger.error(f"Failed to create tool '{config.name}': {str(e)}")
|
||||
raise
|
||||
|
||||
# Helper methods remain largely the same
|
||||
async def _load_from_file(self, path: Union[str, Path]) -> dict:
|
||||
"""Load configuration from JSON or YAML file."""
|
||||
path = Path(path)
|
||||
|
@ -342,15 +340,16 @@ class ComponentFactory:
|
|||
raise FileNotFoundError(f"Config file not found: {path}")
|
||||
|
||||
try:
|
||||
with open(path) as f:
|
||||
if path.suffix == '.json':
|
||||
return json.load(f)
|
||||
elif path.suffix in ('.yml', '.yaml'):
|
||||
return yaml.safe_load(f)
|
||||
async with aiofiles.open(path) as f:
|
||||
content = await f.read()
|
||||
if path.suffix == ".json":
|
||||
return json.loads(content)
|
||||
elif path.suffix in (".yml", ".yaml"):
|
||||
return yaml.safe_load(content)
|
||||
else:
|
||||
raise ValueError(f"Unsupported file format: {path.suffix}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to load file {path}: {str(e)}")
|
||||
raise ValueError(f"Failed to load file {path}: {str(e)}") from e
|
||||
|
||||
def _func_from_string(self, content: str) -> callable:
|
||||
"""Convert function string to callable."""
|
||||
|
@ -362,24 +361,25 @@ class ComponentFactory:
|
|||
return item
|
||||
raise ValueError("No function found in provided code")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to create function: {str(e)}")
|
||||
raise ValueError(f"Failed to create function: {str(e)}") from e
|
||||
|
||||
def _is_version_supported(self, component_type: ComponentType, ver: str) -> bool:
|
||||
def _is_version_supported(self, component_type: ComponentTypes, ver: str) -> bool:
|
||||
"""Check if version is supported for component type."""
|
||||
try:
|
||||
v = version.parse(ver)
|
||||
return ver in self.SUPPORTED_VERSIONS[component_type]
|
||||
except version.InvalidVersion:
|
||||
version = Version(ver)
|
||||
supported = [Version(v) for v in self.SUPPORTED_VERSIONS[component_type]]
|
||||
return any(version == v for v in supported)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
"""Cleanup resources and clear caches."""
|
||||
for model in self._model_cache.values():
|
||||
if hasattr(model, 'cleanup'):
|
||||
if hasattr(model, "cleanup"):
|
||||
await model.cleanup()
|
||||
|
||||
for tool in self._tool_cache.values():
|
||||
if hasattr(tool, 'cleanup'):
|
||||
if hasattr(tool, "cleanup"):
|
||||
await tool.cleanup()
|
||||
|
||||
self._model_cache.clear()
@@ -1,13 +1,11 @@
import logging
from typing import Optional, Union, Dict, Any, List
from pathlib import Path
from loguru import logger
from ..datamodel import (
    Model, Team, Agent, Tool,
    Response, ComponentTypes, LinkTypes,
    ComponentConfigInput
)
from typing import Any, Dict, List, Optional, Union

from loguru import logger

from ..datamodel.db import Agent, LinkTypes, Model, Team, Tool
from ..datamodel.types import ComponentConfigInput, ComponentTypes, Response
from .component_factory import ComponentFactory
from .db_manager import DatabaseManager
|
||||
|
@ -16,10 +14,10 @@ class ConfigurationManager:
|
|||
"""Manages persistence and relationships of components using ComponentFactory for validation"""
|
||||
|
||||
DEFAULT_UNIQUENESS_FIELDS = {
|
||||
ComponentTypes.MODEL: ['model_type', 'model'],
|
||||
ComponentTypes.TOOL: ['name'],
|
||||
ComponentTypes.AGENT: ['agent_type', 'name'],
|
||||
ComponentTypes.TEAM: ['team_type', 'name']
|
||||
ComponentTypes.MODEL: ["model_type", "model"],
|
||||
ComponentTypes.TOOL: ["name"],
|
||||
ComponentTypes.AGENT: ["agent_type", "name"],
|
||||
ComponentTypes.TEAM: ["team_type", "name"],
|
||||
}
|
||||
|
||||
def __init__(self, db_manager: DatabaseManager, uniqueness_fields: Dict[ComponentTypes, List[str]] = None):
|
||||
|
@ -27,7 +25,9 @@ class ConfigurationManager:
|
|||
self.component_factory = ComponentFactory()
|
||||
self.uniqueness_fields = uniqueness_fields or self.DEFAULT_UNIQUENESS_FIELDS
|
||||
|
||||
async def import_component(self, component_config: ComponentConfigInput, user_id: str, check_exists: bool = False) -> Response:
|
||||
async def import_component(
|
||||
self, component_config: ComponentConfigInput, user_id: str, check_exists: bool = False
|
||||
) -> Response:
|
||||
"""
|
||||
Import a component configuration, validate it, and store the resulting component.
|
||||
|
||||
|
@ -41,23 +41,21 @@ class ConfigurationManager:
|
|||
"""
|
||||
try:
|
||||
# Get validated config as dict
|
||||
config = await self.component_factory.load(component_config, return_type='dict')
|
||||
config = await self.component_factory.load(component_config, return_type="dict")
|
||||
|
||||
# Get component type
|
||||
component_type = self._determine_component_type(config)
|
||||
if not component_type:
|
||||
raise ValueError(
|
||||
f"Unable to determine component type from config")
|
||||
raise ValueError("Unable to determine component type from config")
|
||||
|
||||
# Check existence if requested
|
||||
if check_exists:
|
||||
existing = self._check_exists(component_type, config, user_id)
|
||||
if existing:
|
||||
return Response(
|
||||
message=self._format_exists_message(
|
||||
component_type, config),
|
||||
message=self._format_exists_message(component_type, config),
|
||||
status=True,
|
||||
data={"id": existing.id}
|
||||
data={"id": existing.id},
|
||||
)
|
||||
|
||||
# Route to appropriate storage method
|
||||
|
@ -70,8 +68,7 @@ class ConfigurationManager:
|
|||
elif component_type == ComponentTypes.TOOL:
|
||||
return await self._store_tool(config, user_id)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported component type: {component_type}")
|
||||
raise ValueError(f"Unsupported component type: {component_type}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to import component: {str(e)}")
|
||||
|
@ -90,23 +87,21 @@ class ConfigurationManager:
|
|||
Response containing import results for all files
|
||||
"""
|
||||
try:
|
||||
configs = await self.component_factory.load_directory(directory, return_type='dict')
|
||||
configs = await self.component_factory.load_directory(directory, return_type="dict")
|
||||
|
||||
results = []
|
||||
for config in configs:
|
||||
result = await self.import_component(config, user_id, check_exists)
|
||||
results.append({
|
||||
"component": self._get_component_type(config),
|
||||
"status": result.status,
|
||||
"message": result.message,
|
||||
"id": result.data.get("id") if result.status else None
|
||||
})
|
||||
results.append(
|
||||
{
|
||||
"component": self._get_component_type(config),
|
||||
"status": result.status,
|
||||
"message": result.message,
|
||||
"id": result.data.get("id") if result.status else None,
|
||||
}
|
||||
)
|
||||
|
||||
return Response(
|
||||
message="Directory import complete",
|
||||
status=True,
|
||||
data=results
|
||||
)
|
||||
return Response(message="Directory import complete", status=True, data=results)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to import directory: {str(e)}")
|
||||
|
@ -116,10 +111,7 @@ class ConfigurationManager:
|
|||
"""Store team component and manage its relationships with agents"""
|
||||
try:
|
||||
# Store the team
|
||||
team_db = Team(
|
||||
user_id=user_id,
|
||||
config=config
|
||||
)
|
||||
team_db = Team(user_id=user_id, config=config)
|
||||
team_result = self.db_manager.upsert(team_db)
|
||||
if not team_result.status:
|
||||
return team_result
|
||||
|
@ -131,27 +123,17 @@ class ConfigurationManager:
|
|||
if check_exists:
|
||||
# Check for existing agent
|
||||
agent_type = self._determine_component_type(participant)
|
||||
existing_agent = self._check_exists(
|
||||
agent_type, participant, user_id)
|
||||
existing_agent = self._check_exists(agent_type, participant, user_id)
|
||||
if existing_agent:
|
||||
# Link existing agent
|
||||
self.db_manager.link(
|
||||
LinkTypes.TEAM_AGENT,
|
||||
team_id,
|
||||
existing_agent.id
|
||||
)
|
||||
logger.info(
|
||||
f"Linked existing agent to team: {existing_agent}")
|
||||
self.db_manager.link(LinkTypes.TEAM_AGENT, team_id, existing_agent.id)
|
||||
logger.info(f"Linked existing agent to team: {existing_agent}")
|
||||
continue
|
||||
|
||||
# Store and link new agent
|
||||
agent_result = await self._store_agent(participant, user_id, check_exists)
|
||||
if agent_result.status:
|
||||
self.db_manager.link(
|
||||
LinkTypes.TEAM_AGENT,
|
||||
team_id,
|
||||
agent_result.data["id"]
|
||||
)
|
||||
self.db_manager.link(LinkTypes.TEAM_AGENT, team_id, agent_result.data["id"])
|
||||
|
||||
return team_result
|
||||
|
||||
|
@ -163,10 +145,7 @@ class ConfigurationManager:
|
|||
"""Store agent component and manage its relationships with tools and model"""
|
||||
try:
|
||||
# Store the agent
|
||||
agent_db = Agent(
|
||||
user_id=user_id,
|
||||
config=config
|
||||
)
|
||||
agent_db = Agent(user_id=user_id, config=config)
|
||||
agent_result = self.db_manager.upsert(agent_db)
|
||||
if not agent_result.status:
|
||||
return agent_result
|
||||
|
@ -177,64 +156,39 @@ class ConfigurationManager:
|
|||
if "model_client" in config:
|
||||
if check_exists:
|
||||
# Check for existing model
|
||||
model_type = self._determine_component_type(
|
||||
config["model_client"])
|
||||
existing_model = self._check_exists(
|
||||
model_type, config["model_client"], user_id)
|
||||
model_type = self._determine_component_type(config["model_client"])
|
||||
existing_model = self._check_exists(model_type, config["model_client"], user_id)
|
||||
if existing_model:
|
||||
# Link existing model
|
||||
self.db_manager.link(
|
||||
LinkTypes.AGENT_MODEL,
|
||||
agent_id,
|
||||
existing_model.id
|
||||
)
|
||||
logger.info(
|
||||
f"Linked existing model to agent: {existing_model.config.model_type}")
|
||||
self.db_manager.link(LinkTypes.AGENT_MODEL, agent_id, existing_model.id)
|
||||
logger.info(f"Linked existing model to agent: {existing_model.config.model_type}")
|
||||
else:
|
||||
# Store and link new model
|
||||
model_result = await self._store_model(config["model_client"], user_id)
|
||||
if model_result.status:
|
||||
self.db_manager.link(
|
||||
LinkTypes.AGENT_MODEL,
|
||||
agent_id,
|
||||
model_result.data["id"]
|
||||
)
|
||||
self.db_manager.link(LinkTypes.AGENT_MODEL, agent_id, model_result.data["id"])
|
||||
else:
|
||||
# Store and link new model without checking
|
||||
model_result = await self._store_model(config["model_client"], user_id)
|
||||
if model_result.status:
|
||||
self.db_manager.link(
|
||||
LinkTypes.AGENT_MODEL,
|
||||
agent_id,
|
||||
model_result.data["id"]
|
||||
)
|
||||
self.db_manager.link(LinkTypes.AGENT_MODEL, agent_id, model_result.data["id"])
|
||||
|
||||
# Handle tools
|
||||
for tool_config in config.get("tools", []):
|
||||
if check_exists:
|
||||
# Check for existing tool
|
||||
tool_type = self._determine_component_type(tool_config)
|
||||
existing_tool = self._check_exists(
|
||||
tool_type, tool_config, user_id)
|
||||
existing_tool = self._check_exists(tool_type, tool_config, user_id)
|
||||
if existing_tool:
|
||||
# Link existing tool
|
||||
self.db_manager.link(
|
||||
LinkTypes.AGENT_TOOL,
|
||||
agent_id,
|
||||
existing_tool.id
|
||||
)
|
||||
logger.info(
|
||||
f"Linked existing tool to agent: {existing_tool.config.name}")
|
||||
self.db_manager.link(LinkTypes.AGENT_TOOL, agent_id, existing_tool.id)
|
||||
logger.info(f"Linked existing tool to agent: {existing_tool.config.name}")
|
||||
continue
|
||||
|
||||
# Store and link new tool
|
||||
tool_result = await self._store_tool(tool_config, user_id)
|
||||
if tool_result.status:
|
||||
self.db_manager.link(
|
||||
LinkTypes.AGENT_TOOL,
|
||||
agent_id,
|
||||
tool_result.data["id"]
|
||||
)
|
||||
self.db_manager.link(LinkTypes.AGENT_TOOL, agent_id, tool_result.data["id"])
|
||||
|
||||
return agent_result
|
||||
|
||||
|
@ -245,10 +199,7 @@ class ConfigurationManager:
|
|||
async def _store_model(self, config: dict, user_id: str) -> Response:
|
||||
"""Store model component (leaf node - no relationships)"""
|
||||
try:
|
||||
model_db = Model(
|
||||
user_id=user_id,
|
||||
config=config
|
||||
)
|
||||
model_db = Model(user_id=user_id, config=config)
|
||||
return self.db_manager.upsert(model_db)
|
||||
|
||||
except Exception as e:
|
||||
|
@ -258,17 +209,16 @@ class ConfigurationManager:
|
|||
async def _store_tool(self, config: dict, user_id: str) -> Response:
|
||||
"""Store tool component (leaf node - no relationships)"""
|
||||
try:
|
||||
tool_db = Tool(
|
||||
user_id=user_id,
|
||||
config=config
|
||||
)
|
||||
tool_db = Tool(user_id=user_id, config=config)
|
||||
return self.db_manager.upsert(tool_db)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store tool: {str(e)}")
|
||||
return Response(message=str(e), status=False)
|
||||
|
||||
def _check_exists(self, component_type: ComponentTypes, config: dict, user_id: str) -> Optional[Union[Model, Tool, Agent, Team]]:
|
||||
def _check_exists(
|
||||
self, component_type: ComponentTypes, config: dict, user_id: str
|
||||
) -> Optional[Union[Model, Tool, Agent, Team]]:
|
||||
"""Check if component exists based on configured uniqueness fields."""
|
||||
fields = self.uniqueness_fields.get(component_type, [])
|
||||
if not fields:
|
||||
|
@ -278,17 +228,13 @@ class ConfigurationManager:
|
|||
ComponentTypes.MODEL: Model,
|
||||
ComponentTypes.TOOL: Tool,
|
||||
ComponentTypes.AGENT: Agent,
|
||||
ComponentTypes.TEAM: Team
|
||||
ComponentTypes.TEAM: Team,
|
||||
}.get(component_type)
|
||||
|
||||
components = self.db_manager.get(
|
||||
component_class, {"user_id": user_id}).data
|
||||
components = self.db_manager.get(component_class, {"user_id": user_id}).data
|
||||
|
||||
for component in components:
|
||||
matches = all(
|
||||
component.config.get(field) == config.get(field)
|
||||
for field in fields
|
||||
)
|
||||
matches = all(component.config.get(field) == config.get(field) for field in fields)
|
||||
if matches:
|
||||
return component
|
||||
|
||||
|
|
|
@@ -1,59 +1,77 @@
from pathlib import Path
import threading
from datetime import datetime
from pathlib import Path
from typing import Optional

from loguru import logger
from sqlalchemy import exc, text, func
from sqlalchemy import exc, func, inspect, text
from sqlmodel import Session, SQLModel, and_, create_engine, select

from ..datamodel import LinkTypes, Response
from .schema_manager import SchemaManager

from ..datamodel import (
    Response,
    LinkTypes
)
# from .dbutils import init_db_samples


class DatabaseManager:
    """A class to manage database operations"""

    _init_lock = threading.Lock()

    def __init__(
        self,
        engine_uri: str,
        base_dir: Optional[Path | str] = None,
        auto_upgrade: bool = True
    ):
    def __init__(self, engine_uri: str, base_dir: Optional[Path] = None):
        """
        Initialize DatabaseManager with optional custom base directory.
        Initialize DatabaseManager with database connection settings.
        Does not perform any database operations.

        Args:
            engine_uri: Database connection URI
            base_dir: Custom base directory for Alembic files. If None, uses current working directory
            auto_upgrade: Whether to automatically upgrade schema when differences found
            engine_uri: Database connection URI (e.g. sqlite:///db.sqlite3)
            base_dir: Base directory for migration files. If None, uses current directory
        """
        # Convert string path to Path object if necessary
        if isinstance(base_dir, str):
            base_dir = Path(base_dir)

        connection_args = {
            "check_same_thread": True
        } if "sqlite" in engine_uri else {}
        connection_args = {"check_same_thread": True} if "sqlite" in engine_uri else {}

        self.engine = create_engine(engine_uri, connect_args=connection_args)
        self.schema_manager = SchemaManager(
            engine=self.engine,
            base_dir=base_dir,
            auto_upgrade=auto_upgrade,
        )
        # Check and upgrade on startup
        upgraded, status = self.schema_manager.check_and_upgrade()
        if upgraded:
            logger.info("Database schema was upgraded automatically")
        else:
            logger.info(f"Schema status: {status}")

    def initialize_database(self, auto_upgrade: bool = False, force_init_alembic: bool = True) -> Response:
        """
        Initialize database and migrations in the correct order.

        Args:
            auto_upgrade: If True, automatically generate and apply migrations for schema changes
            force_init_alembic: If True, reinitialize alembic configuration even if it exists
        """
        if not self._init_lock.acquire(blocking=False):
            return Response(message="Database initialization already in progress", status=False)

        try:
            inspector = inspect(self.engine)
            tables_exist = inspector.get_table_names()

            if not tables_exist:
                # Fresh install - create tables and initialize migrations
                logger.info("Creating database tables...")
                SQLModel.metadata.create_all(self.engine)

                if self.schema_manager.initialize_migrations(force=force_init_alembic):
                    return Response(message="Database initialized successfully", status=True)
                return Response(message="Failed to initialize migrations", status=False)

            # Handle existing database
            if auto_upgrade:
                logger.info("Checking database schema...")
                if self.schema_manager.ensure_schema_up_to_date():  # <-- Use this instead
                    return Response(message="Database schema is up to date", status=True)
                return Response(message="Database upgrade failed", status=False)

            return Response(message="Database is ready", status=True)

        except Exception as e:
            error_msg = f"Database initialization failed: {str(e)}"
            logger.error(error_msg)
            return Response(message=error_msg, status=False)
        finally:
            self._init_lock.release()
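The split above is the substance of the db instantiation and auto_upgrade fixes called out in the commit message: __init__ only stores connection settings, while initialize_database creates tables on a fresh install or checks/upgrades the schema on an existing one. A minimal caller sketch (the URI is an illustrative value):

from autogenstudio.database import DatabaseManager

# Construction no longer touches the database; schema work is an explicit second step.
db = DatabaseManager(engine_uri="sqlite:///autogen.db")
response = db.initialize_database(auto_upgrade=True)  # fresh install creates tables, otherwise checks/upgrades schema
if not response.status:
    print(response.message)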
|
||||
def reset_db(self, recreate_tables: bool = True):
|
||||
"""
|
||||
|
@ -65,11 +83,7 @@ class DatabaseManager:
|
|||
"""
|
||||
if not self._init_lock.acquire(blocking=False):
|
||||
logger.warning("Database reset already in progress")
|
||||
return Response(
|
||||
message="Database reset already in progress",
|
||||
status=False,
|
||||
data=None
|
||||
)
|
||||
return Response(message="Database reset already in progress", status=False, data=None)
|
||||
|
||||
try:
|
||||
# Dispose existing connections
|
||||
|
@ -77,16 +91,16 @@ class DatabaseManager:
|
|||
with Session(self.engine) as session:
|
||||
try:
|
||||
# Disable foreign key checks for SQLite
|
||||
if 'sqlite' in str(self.engine.url):
|
||||
session.exec(text('PRAGMA foreign_keys=OFF'))
|
||||
if "sqlite" in str(self.engine.url):
|
||||
session.exec(text("PRAGMA foreign_keys=OFF"))
|
||||
|
||||
# Drop all tables
|
||||
SQLModel.metadata.drop_all(self.engine)
|
||||
logger.info("All tables dropped successfully")
|
||||
|
||||
# Re-enable foreign key checks for SQLite
|
||||
if 'sqlite' in str(self.engine.url):
|
||||
session.exec(text('PRAGMA foreign_keys=ON'))
|
||||
if "sqlite" in str(self.engine.url):
|
||||
session.exec(text("PRAGMA foreign_keys=ON"))
|
||||
|
||||
session.commit()
|
||||
|
||||
|
@ -99,48 +113,29 @@ class DatabaseManager:
|
|||
|
||||
if recreate_tables:
|
||||
logger.info("Recreating tables...")
|
||||
self.create_db_and_tables()
|
||||
self.initialize_database(auto_upgrade=False, force_init_alembic=True)
|
||||
|
||||
return Response(
|
||||
message="Database reset successfully" if recreate_tables else "Database tables dropped successfully",
|
||||
status=True,
|
||||
data=None
|
||||
data=None,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error while resetting database: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return Response(
|
||||
message=error_msg,
|
||||
status=False,
|
||||
data=None
|
||||
)
|
||||
return Response(message=error_msg, status=False, data=None)
|
||||
finally:
|
||||
if self._init_lock.locked():
|
||||
self._init_lock.release()
|
||||
logger.info("Database reset lock released")
|
||||
|
||||
def create_db_and_tables(self):
|
||||
"""Create a new database and tables"""
|
||||
with self._init_lock:
|
||||
try:
|
||||
SQLModel.metadata.create_all(self.engine)
|
||||
logger.info("Database tables created successfully")
|
||||
try:
|
||||
# init_db_samples(self)
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"Error while initializing database samples: " + str(e))
|
||||
except Exception as e:
|
||||
logger.info("Error while creating database tables:" + str(e))
|
||||
|
||||
def upsert(self, model: SQLModel, return_json: bool = True):
|
||||
"""Create or update an entity
|
||||
|
||||
Args:
|
||||
model (SQLModel): The model instance to create or update
|
||||
return_json (bool, optional): If True, returns the model as a dictionary.
|
||||
return_json (bool, optional): If True, returns the model as a dictionary.
|
||||
If False, returns the SQLModel instance. Defaults to True.
|
||||
|
||||
Returns:
|
||||
|
@ -152,8 +147,7 @@ class DatabaseManager:
|
|||
|
||||
with Session(self.engine) as session:
|
||||
try:
|
||||
existing_model = session.exec(
|
||||
select(model_class).where(model_class.id == model.id)).first()
|
||||
existing_model = session.exec(select(model_class).where(model_class.id == model.id)).first()
|
||||
if existing_model:
|
||||
model.updated_at = datetime.now()
|
||||
for key, value in model.model_dump().items():
|
||||
|
@ -166,8 +160,7 @@ class DatabaseManager:
|
|||
session.refresh(model)
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error("Error while updating/creating " +
|
||||
str(model_class.__name__) + ": " + str(e))
|
||||
logger.error("Error while updating/creating " + str(model_class.__name__) + ": " + str(e))
|
||||
status = False
|
||||
|
||||
return Response(
|
||||
|
@ -199,25 +192,21 @@ class DatabaseManager:
|
|||
try:
|
||||
statement = select(model_class)
|
||||
if filters:
|
||||
conditions = [getattr(model_class, col) ==
|
||||
value for col, value in filters.items()]
|
||||
conditions = [getattr(model_class, col) == value for col, value in filters.items()]
|
||||
statement = statement.where(and_(*conditions))
|
||||
|
||||
if hasattr(model_class, "created_at") and order:
|
||||
order_by_clause = getattr(
|
||||
model_class.created_at, order)() # Dynamically apply asc/desc
|
||||
order_by_clause = getattr(model_class.created_at, order)() # Dynamically apply asc/desc
|
||||
statement = statement.order_by(order_by_clause)
|
||||
|
||||
items = session.exec(statement).all()
|
||||
result = [self._model_to_dict(
|
||||
item) if return_json else item for item in items]
|
||||
result = [self._model_to_dict(item) if return_json else item for item in items]
|
||||
status_message = f"{model_class.__name__} Retrieved Successfully"
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
status = False
|
||||
status_message = f"Error while fetching {model_class.__name__}"
|
||||
logger.error("Error while getting items: " +
|
||||
str(model_class.__name__) + " " + str(e))
|
||||
logger.error("Error while getting items: " + str(model_class.__name__) + " " + str(e))
|
||||
|
||||
return Response(message=status_message, status=status, data=result)
|
||||
|
||||
|
@ -230,8 +219,7 @@ class DatabaseManager:
|
|||
try:
|
||||
statement = select(model_class)
|
||||
if filters:
|
||||
conditions = [
|
||||
getattr(model_class, col) == value for col, value in filters.items()]
|
||||
conditions = [getattr(model_class, col) == value for col, value in filters.items()]
|
||||
statement = statement.where(and_(*conditions))
|
||||
|
||||
rows = session.exec(statement).all()
|
||||
|
@ -290,8 +278,7 @@ class DatabaseManager:
|
|||
select(link_table).where(
|
||||
and_(
|
||||
getattr(link_table, primary_id_field) == primary_id,
|
||||
getattr(
|
||||
link_table, secondary_id_field) == secondary_id
|
||||
getattr(link_table, secondary_id_field) == secondary_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
|
@ -302,37 +289,24 @@ class DatabaseManager:
|
|||
# Get the next sequence number if not provided
|
||||
if sequence is None:
|
||||
max_seq_result = session.exec(
|
||||
select(func.max(link_table.sequence)).where(
|
||||
getattr(link_table, primary_id_field) == primary_id
|
||||
)
|
||||
select(func.max(link_table.sequence)).where(getattr(link_table, primary_id_field) == primary_id)
|
||||
).first()
|
||||
sequence = 0 if max_seq_result is None else max_seq_result + 1
|
||||
|
||||
# Create new link
|
||||
new_link = link_table(**{
|
||||
primary_id_field: primary_id,
|
||||
secondary_id_field: secondary_id,
|
||||
'sequence': sequence
|
||||
})
|
||||
new_link = link_table(
|
||||
**{primary_id_field: primary_id, secondary_id_field: secondary_id, "sequence": sequence}
|
||||
)
|
||||
session.add(new_link)
|
||||
session.commit()
|
||||
|
||||
return Response(
|
||||
message=f"Entities linked successfully with sequence {sequence}",
|
||||
status=True
|
||||
)
|
||||
return Response(message=f"Entities linked successfully with sequence {sequence}", status=True)
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
return Response(message=f"Error linking entities: {str(e)}", status=False)
|
||||
|
||||
def unlink(
|
||||
self,
|
||||
link_type: LinkTypes,
|
||||
primary_id: int,
|
||||
secondary_id: int,
|
||||
sequence: Optional[int] = None
|
||||
):
|
||||
def unlink(self, link_type: LinkTypes, primary_id: int, secondary_id: int, sequence: Optional[int] = None):
|
||||
"""Unlink two entities and reorder sequences if needed."""
|
||||
with Session(self.engine) as session:
|
||||
try:
|
||||
|
@ -349,13 +323,12 @@ class DatabaseManager:
|
|||
statement = select(link_table).where(
|
||||
and_(
|
||||
getattr(link_table, primary_id_field) == primary_id,
|
||||
getattr(link_table, secondary_id_field) == secondary_id
|
||||
getattr(link_table, secondary_id_field) == secondary_id,
|
||||
)
|
||||
)
|
||||
|
||||
if sequence is not None:
|
||||
statement = statement.where(
|
||||
link_table.sequence == sequence)
|
||||
statement = statement.where(link_table.sequence == sequence)
|
||||
|
||||
existing_link = session.exec(statement).first()
|
||||
|
||||
|
@ -379,10 +352,7 @@ class DatabaseManager:
|
|||
|
||||
session.commit()
|
||||
|
||||
return Response(
|
||||
message="Entities unlinked successfully and sequences reordered",
|
||||
status=True
|
||||
)
|
||||
return Response(message="Entities unlinked successfully and sequences reordered", status=True)
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
|
@ -414,22 +384,14 @@ class DatabaseManager:
|
|||
.order_by(link_table.sequence)
|
||||
).all()
|
||||
|
||||
result = [
|
||||
item.model_dump() if return_json else item for item in items]
|
||||
result = [item.model_dump() if return_json else item for item in items]
|
||||
|
||||
return Response(
|
||||
message="Linked entities retrieved successfully",
|
||||
status=True,
|
||||
data=result
|
||||
)
|
||||
return Response(message="Linked entities retrieved successfully", status=True, data=result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting linked entities: {str(e)}")
|
||||
return Response(
|
||||
message=f"Error getting linked entities: {str(e)}",
|
||||
status=False,
|
||||
data=[]
|
||||
)
|
||||
return Response(message=f"Error getting linked entities: {str(e)}", status=False, data=[])
|
||||
|
||||
# Add new close method
|
||||
|
||||
async def close(self):
|
||||
|
|
|
@@ -1,68 +1,71 @@
import os
from pathlib import Path
import shutil
from typing import Optional, Tuple, List
from loguru import logger
from pathlib import Path
from typing import List, Optional, Tuple

import sqlmodel
from alembic import command
from alembic.autogenerate import compare_metadata
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from alembic.autogenerate import compare_metadata
from sqlalchemy import Engine
from sqlmodel import SQLModel
from alembic.util.exc import CommandError
from loguru import logger
from sqlalchemy import Engine, text
from sqlmodel import SQLModel
|
||||
|
||||
class SchemaManager:
|
||||
"""
|
||||
Manages database schema validation and migrations using Alembic.
|
||||
Provides automatic schema validation, migrations, and safe upgrades.
|
||||
|
||||
Args:
|
||||
engine: SQLAlchemy engine instance
|
||||
auto_upgrade: Whether to automatically upgrade schema when differences found
|
||||
init_mode: Controls initialization behavior:
|
||||
- "none": No automatic initialization (raises error if not set up)
|
||||
- "auto": Initialize if not present (default)
|
||||
- "force": Always reinitialize, removing existing configuration
|
||||
Operations are initiated explicitly by DatabaseManager.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine: Engine,
|
||||
base_dir: Optional[Path] = None,
|
||||
auto_upgrade: bool = True,
|
||||
init_mode: str = "auto"
|
||||
):
|
||||
if init_mode not in ["none", "auto", "force"]:
|
||||
raise ValueError("init_mode must be one of: none, auto, force")
|
||||
"""
|
||||
Initialize configuration only - no filesystem or DB operations.
|
||||
|
||||
Args:
|
||||
engine: SQLAlchemy engine instance
|
||||
base_dir: Base directory for Alembic files. If None, uses current working directory
|
||||
"""
|
||||
# Convert string path to Path object if necessary
|
||||
if isinstance(base_dir, str):
|
||||
base_dir = Path(base_dir)
|
||||
|
||||
self.engine = engine
|
||||
self.auto_upgrade = auto_upgrade
|
||||
|
||||
# Use provided base_dir or default to class file location
|
||||
self.base_dir = base_dir or Path(__file__).parent
|
||||
self.alembic_dir = self.base_dir / 'alembic'
|
||||
self.alembic_ini_path = self.base_dir / 'alembic.ini'
|
||||
self.alembic_dir = self.base_dir / "alembic"
|
||||
self.alembic_ini_path = self.base_dir / "alembic.ini"
|
||||
|
||||
# Create base directory if it doesn't exist
|
||||
self.base_dir.mkdir(parents=True, exist_ok=True)
|
||||
def initialize_migrations(self, force: bool = False) -> bool:
|
||||
try:
|
||||
if force:
|
||||
logger.info("Force reinitialization of migrations...")
|
||||
self._cleanup_existing_alembic()
|
||||
if not self._initialize_alembic():
|
||||
return False
|
||||
else:
|
||||
try:
|
||||
self._validate_alembic_setup()
|
||||
logger.info("Using existing Alembic configuration")
|
||||
self._update_configuration()
|
||||
except FileNotFoundError:
|
||||
logger.info("Initializing new Alembic configuration")
|
||||
if not self._initialize_alembic():
|
||||
return False
|
||||
|
||||
# Initialize based on mode
|
||||
if init_mode == "force":
|
||||
self._cleanup_existing_alembic()
|
||||
self._initialize_alembic()
|
||||
else:
|
||||
try:
|
||||
self._validate_alembic_setup()
|
||||
logger.info("Using existing Alembic configuration")
|
||||
# Update existing configuration
|
||||
self._update_configuration()
|
||||
except FileNotFoundError:
|
||||
if init_mode == "none":
|
||||
raise
|
||||
logger.info("Initializing new Alembic configuration")
|
||||
self._initialize_alembic()
|
||||
# Only generate initial revision if alembic is properly initialized
|
||||
logger.info("Creating initial migration...")
|
||||
return self.generate_revision("Initial schema") is not None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize migrations: {e}")
|
||||
return False
|
||||
|
||||
def _update_configuration(self) -> None:
|
||||
"""Updates existing Alembic configuration with current settings."""
|
||||
|
@ -70,11 +73,11 @@ class SchemaManager:
|
|||
|
||||
# Update alembic.ini
|
||||
config_content = self._generate_alembic_ini_content()
|
||||
with open(self.alembic_ini_path, 'w') as f:
|
||||
with open(self.alembic_ini_path, "w") as f:
|
||||
f.write(config_content)
|
||||
|
||||
# Update env.py
|
||||
env_path = self.alembic_dir / 'env.py'
|
||||
env_path = self.alembic_dir / "env.py"
|
||||
if env_path.exists():
|
||||
self._update_env_py(env_path)
|
||||
else:
|
||||
|
@ -82,37 +85,22 @@ class SchemaManager:
|
|||
|
||||
def _cleanup_existing_alembic(self) -> None:
|
||||
"""
|
||||
Safely removes existing Alembic configuration while preserving versions directory.
|
||||
Completely remove existing Alembic configuration including versions.
|
||||
For fresh initialization, we don't need to preserve anything.
|
||||
"""
|
||||
logger.info(
|
||||
"Cleaning up existing Alembic configuration while preserving versions...")
|
||||
logger.info("Cleaning up existing Alembic configuration...")
|
||||
|
||||
# Create a backup of versions directory if it exists
|
||||
if self.alembic_dir.exists() and (self.alembic_dir / 'versions').exists():
|
||||
logger.info("Preserving existing versions directory")
|
||||
|
||||
# Remove alembic directory contents EXCEPT versions
|
||||
# Remove entire alembic directory if it exists
|
||||
if self.alembic_dir.exists():
|
||||
for item in self.alembic_dir.iterdir():
|
||||
if item.name != 'versions':
|
||||
try:
|
||||
if item.is_dir():
|
||||
shutil.rmtree(item)
|
||||
logger.info(f"Removed directory: {item}")
|
||||
else:
|
||||
item.unlink()
|
||||
logger.info(f"Removed file: {item}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to remove {item}: {e}")
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(self.alembic_dir)
|
||||
logger.info(f"Removed alembic directory: {self.alembic_dir}")
|
||||
|
||||
# Remove alembic.ini if it exists
|
||||
if self.alembic_ini_path.exists():
|
||||
try:
|
||||
self.alembic_ini_path.unlink()
|
||||
logger.info(
|
||||
f"Removed existing alembic.ini: {self.alembic_ini_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to remove alembic.ini: {e}")
|
||||
self.alembic_ini_path.unlink()
|
||||
logger.info("Removed alembic.ini")
|
||||
|
||||
def _ensure_alembic_setup(self, *, force: bool = False) -> None:
|
||||
"""
|
||||
|
@ -124,51 +112,52 @@ class SchemaManager:
|
|||
try:
|
||||
self._validate_alembic_setup()
|
||||
if force:
|
||||
logger.info(
|
||||
"Force initialization requested. Cleaning up existing configuration...")
|
||||
logger.info("Force initialization requested. Cleaning up existing configuration...")
|
||||
self._cleanup_existing_alembic()
|
||||
self._initialize_alembic()
|
||||
except FileNotFoundError:
|
||||
logger.info("Alembic configuration not found. Initializing...")
|
||||
if self.alembic_dir.exists():
|
||||
logger.warning(
|
||||
"Found existing alembic directory but missing configuration")
|
||||
logger.warning("Found existing alembic directory but missing configuration")
|
||||
self._cleanup_existing_alembic()
|
||||
self._initialize_alembic()
|
||||
logger.info("Alembic initialization complete")
|
||||
|
||||
def _initialize_alembic(self) -> None:
|
||||
logger.info("Initializing Alembic configuration...")
|
||||
|
||||
# Create directories first
|
||||
self.alembic_dir.mkdir(exist_ok=True)
|
||||
versions_dir = self.alembic_dir / 'versions'
|
||||
versions_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Create env.py BEFORE running command.init
|
||||
env_path = self.alembic_dir / 'env.py'
|
||||
if not env_path.exists():
|
||||
self._create_minimal_env_py(env_path)
|
||||
logger.info("Created new env.py")
|
||||
|
||||
# Write alembic.ini
|
||||
config_content = self._generate_alembic_ini_content()
|
||||
with open(self.alembic_ini_path, 'w') as f:
|
||||
f.write(config_content)
|
||||
logger.info("Created alembic.ini")
|
||||
|
||||
# Now run alembic init
|
||||
def _initialize_alembic(self) -> bool:
|
||||
"""Initialize alembic structure and configuration"""
|
||||
try:
|
||||
config = self.get_alembic_config()
|
||||
# Ensure parent directory exists
|
||||
self.alembic_dir.parent.mkdir(exist_ok=True)
|
||||
|
||||
# Run alembic init to create fresh directory structure
|
||||
logger.info("Initializing alembic directory structure...")
|
||||
|
||||
# Create initial config file for alembic init
|
||||
config_content = self._generate_alembic_ini_content()
|
||||
with open(self.alembic_ini_path, "w") as f:
|
||||
f.write(config_content)
|
||||
|
||||
# Use the config we just created
|
||||
config = Config(str(self.alembic_ini_path))
|
||||
command.init(config, str(self.alembic_dir))
|
||||
logger.info("Initialized Alembic directory structure")
|
||||
except CommandError as e:
|
||||
if "already exists" not in str(e):
|
||||
raise
|
||||
|
||||
# Update script template after initialization
|
||||
self.update_script_template()
|
||||
|
||||
# Update env.py with our customizations
|
||||
self._update_env_py(self.alembic_dir / "env.py")
|
||||
|
||||
logger.info("Alembic initialization complete")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
# Explicitly convert error to string
|
||||
logger.error(f"Failed to initialize alembic: {str(e)}")
|
||||
return False
|
||||
|
||||
def _create_minimal_env_py(self, env_path: Path) -> None:
|
||||
"""Creates a minimal env.py file for Alembic."""
|
||||
content = '''
|
||||
content = """
|
||||
from logging.config import fileConfig
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
|
@ -201,7 +190,7 @@ def run_migrations_online() -> None:
|
|||
)
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
compare_type=True
|
||||
)
|
||||
|
@ -211,9 +200,9 @@ def run_migrations_online() -> None:
|
|||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()'''
|
||||
run_migrations_online()"""
|
||||
|
||||
with open(env_path, 'w') as f:
|
||||
with open(env_path, "w") as f:
|
||||
f.write(content)
|
||||
|
||||
def _generate_alembic_ini_content(self) -> str:
|
||||
|
@ -260,6 +249,29 @@ format = %(levelname)-5.5s [%(name)s] %(message)s
|
|||
datefmt = %H:%M:%S
|
||||
""".strip()
|
||||
|
||||
def update_script_template(self):
|
||||
"""Update the Alembic script template to include SQLModel."""
|
||||
template_path = self.alembic_dir / "script.py.mako"
|
||||
try:
|
||||
with open(template_path, "r") as f:
|
||||
content = f.read()
|
||||
|
||||
# Add sqlmodel import to imports section
|
||||
import_section = "from alembic import op\nimport sqlalchemy as sa"
|
||||
new_imports = "from alembic import op\nimport sqlalchemy as sa\nimport sqlmodel"
|
||||
|
||||
content = content.replace(import_section, new_imports)
|
||||
|
||||
with open(template_path, "w") as f:
|
||||
f.write(content)
|
||||
|
||||
logger.info("Updated script template")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update script template: {e}")
|
||||
return False
|
||||
|
||||
def _update_env_py(self, env_path: Path) -> None:
|
||||
"""
|
||||
Updates the env.py file to use SQLModel metadata.
|
||||
|
@ -268,27 +280,45 @@ datefmt = %H:%M:%S
|
|||
self._create_minimal_env_py(env_path)
|
||||
return
|
||||
try:
|
||||
with open(env_path, 'r') as f:
|
||||
with open(env_path, "r") as f:
|
||||
content = f.read()
|
||||
|
||||
# Add SQLModel import
|
||||
# Add SQLModel import if not present
|
||||
if "from sqlmodel import SQLModel" not in content:
|
||||
content = "from sqlmodel import SQLModel\n" + content
|
||||
|
||||
# Replace target_metadata
|
||||
content = content.replace("target_metadata = None", "target_metadata = SQLModel.metadata")
|
||||
|
||||
# Update both configure blocks properly
|
||||
content = content.replace(
|
||||
"target_metadata = None",
|
||||
"target_metadata = SQLModel.metadata"
|
||||
"""context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)""",
|
||||
"""context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
compare_type=True,
|
||||
)""",
|
||||
)
|
||||
|
||||
# Add compare_type=True to context.configure
|
||||
if "context.configure(" in content and "compare_type=True" not in content:
|
||||
content = content.replace(
|
||||
"context.configure(",
|
||||
"context.configure(compare_type=True,"
|
||||
)
|
||||
content = content.replace(
|
||||
""" context.configure(
|
||||
connection=connection, target_metadata=target_metadata
|
||||
)""",
|
||||
""" context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
compare_type=True,
|
||||
)""",
|
||||
)
|
||||
|
||||
with open(env_path, 'w') as f:
|
||||
with open(env_path, "w") as f:
|
||||
f.write(content)
|
||||
|
||||
logger.info("Updated env.py with SQLModel metadata")
|
||||
|
@ -297,6 +327,7 @@ datefmt = %H:%M:%S
|
|||
raise
|
||||
|
||||
# Fixed: use keyword-only argument
|
||||
|
||||
def _ensure_alembic_setup(self, *, force: bool = False) -> None:
|
||||
"""
|
||||
Ensures Alembic is properly set up, initializing if necessary.
|
||||
|
@ -307,32 +338,24 @@ datefmt = %H:%M:%S
|
|||
try:
|
||||
self._validate_alembic_setup()
|
||||
if force:
|
||||
logger.info(
|
||||
"Force initialization requested. Cleaning up existing configuration...")
|
||||
logger.info("Force initialization requested. Cleaning up existing configuration...")
|
||||
self._cleanup_existing_alembic()
|
||||
self._initialize_alembic()
|
||||
except FileNotFoundError:
|
||||
logger.info("Alembic configuration not found. Initializing...")
|
||||
if self.alembic_dir.exists():
|
||||
logger.warning(
|
||||
"Found existing alembic directory but missing configuration")
|
||||
logger.warning("Found existing alembic directory but missing configuration")
|
||||
self._cleanup_existing_alembic()
|
||||
self._initialize_alembic()
|
||||
logger.info("Alembic initialization complete")
|
||||
|
||||
def _validate_alembic_setup(self) -> None:
|
||||
"""Validates that Alembic is properly configured."""
|
||||
required_files = [
|
||||
self.alembic_ini_path,
|
||||
self.alembic_dir / 'env.py',
|
||||
self.alembic_dir / 'versions'
|
||||
]
|
||||
required_files = [self.alembic_ini_path, self.alembic_dir / "env.py", self.alembic_dir / "versions"]
|
||||
|
||||
missing = [f for f in required_files if not f.exists()]
|
||||
if missing:
|
||||
raise FileNotFoundError(
|
||||
f"Alembic configuration incomplete. Missing: {', '.join(str(f) for f in missing)}"
|
||||
)
|
||||
raise FileNotFoundError(f"Alembic configuration incomplete. Missing: {', '.join(str(f) for f in missing)}")
|
||||
|
||||
def get_alembic_config(self) -> Config:
|
||||
"""
|
||||
|
@ -430,7 +453,7 @@ datefmt = %H:%M:%S
|
|||
|
||||
def check_and_upgrade(self) -> Tuple[bool, str]:
|
||||
"""
|
||||
Checks schema status and upgrades if necessary (and auto_upgrade is True).
|
||||
Checks schema status and upgrades if necessary.
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (action_taken, status_message)
|
||||
|
@ -438,13 +461,11 @@ datefmt = %H:%M:%S
|
|||
needs_upgrade, status = self.check_schema_status()
|
||||
|
||||
if needs_upgrade:
|
||||
if self.auto_upgrade:
|
||||
if self.upgrade_schema():
|
||||
return True, "Schema was automatically upgraded"
|
||||
else:
|
||||
return False, "Automatic schema upgrade failed"
|
||||
# Remove the auto_upgrade check since we explicitly called this method
|
||||
if self.upgrade_schema():
|
||||
return True, "Schema was automatically upgraded"
|
||||
else:
|
||||
return False, f"Schema needs upgrade but auto_upgrade is disabled. Status: {status}"
|
||||
return False, "Automatic schema upgrade failed"
|
||||
|
||||
return False, status
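A minimal caller sketch for the revised check_and_upgrade, assuming a SchemaManager instance named schema_manager has already been constructed; the caller only reacts to the (action_taken, status_message) tuple:

# Hypothetical caller; schema_manager is an existing SchemaManager instance.
upgraded, status = schema_manager.check_and_upgrade()
if not upgraded:
    logger.warning(f"Schema upgrade skipped or failed: {status}")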
|
||||
|
||||
|
@ -460,11 +481,7 @@ datefmt = %H:%M:%S
|
|||
"""
|
||||
try:
|
||||
config = self.get_alembic_config()
|
||||
command.revision(
|
||||
config,
|
||||
message=message,
|
||||
autogenerate=True
|
||||
)
|
||||
command.revision(config, message=message, autogenerate=True)
|
||||
return self.get_head_revision()
|
||||
|
||||
except Exception as e:
|
||||
|
@ -512,26 +529,40 @@ datefmt = %H:%M:%S
|
|||
|
||||
def ensure_schema_up_to_date(self) -> bool:
|
||||
"""
|
||||
Ensures the database schema is up to date, generating and applying migrations if needed.
|
||||
|
||||
Returns:
|
||||
bool: True if schema is up to date or was successfully updated
|
||||
Reset migrations and create fresh migration for current schema state.
|
||||
"""
|
||||
try:
|
||||
# Check for unmigrated changes
|
||||
differences = self.get_schema_differences()
|
||||
if differences:
|
||||
# Generate new migration
|
||||
revision = self.generate_revision("auto-generated")
|
||||
if not revision:
|
||||
return False
|
||||
logger.info(f"Generated new migration: {revision}")
|
||||
logger.info("Resetting migrations and updating to current schema...")
|
||||
|
||||
# Apply any pending migrations
|
||||
upgraded, status = self.check_and_upgrade()
|
||||
if not upgraded and "needs upgrade" in status.lower():
|
||||
# 1. Clear the entire alembic directory
|
||||
if self.alembic_dir.exists():
|
||||
shutil.rmtree(self.alembic_dir)
|
||||
logger.info("Cleared alembic directory")
|
||||
|
||||
# 2. Clear alembic_version table
|
||||
with self.engine.connect() as connection:
|
||||
connection.execute(text("DROP TABLE IF EXISTS alembic_version"))
|
||||
connection.commit()
|
||||
logger.info("Reset alembic version")
|
||||
|
||||
# 3. Reinitialize alembic from scratch
|
||||
if not self._initialize_alembic():
|
||||
logger.error("Failed to reinitialize alembic")
|
||||
return False
|
||||
|
||||
# 4. Generate fresh migration from current schema
|
||||
revision = self.generate_revision("current_schema")
|
||||
if not revision:
|
||||
logger.error("Failed to generate new migration")
|
||||
return False
|
||||
logger.info(f"Generated fresh migration: {revision}")
|
||||
|
||||
# 5. Apply the migration
|
||||
if not self.upgrade_schema():
|
||||
logger.error("Failed to apply migration")
|
||||
return False
|
||||
logger.info("Successfully applied migration")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
@ -1,2 +1,11 @@
|
|||
from .db import *
|
||||
from .types import *
|
||||
from .db import Agent, LinkTypes, Message, Model, Run, RunStatus, Session, Team, Tool
|
||||
from .types import (
|
||||
AgentConfig,
|
||||
ComponentConfigInput,
|
||||
MessageConfig,
|
||||
ModelConfig,
|
||||
Response,
|
||||
TeamConfig,
|
||||
TeamResult,
|
||||
ToolConfig,
|
||||
)
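Because the wildcard imports were replaced with explicit names, callers can import the public datamodel symbols directly (package name assumed from the surrounding project layout):

from autogenstudio.datamodel import Run, RunStatus, TeamResult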
|
||||
|
|
|
@ -2,24 +2,28 @@
|
|||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Union, Tuple, Type
|
||||
from sqlalchemy import ForeignKey, Integer, UniqueConstraint
|
||||
from sqlmodel import JSON, Column, DateTime, Field, SQLModel, func, Relationship, SQLModel
|
||||
from typing import List, Optional, Tuple, Type, Union
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from .types import ToolConfig, ModelConfig, AgentConfig, TeamConfig, MessageConfig, MessageMeta
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import ForeignKey, Integer, UniqueConstraint
|
||||
from sqlmodel import JSON, Column, DateTime, Field, Relationship, SQLModel, func
|
||||
|
||||
from .types import AgentConfig, MessageConfig, MessageMeta, ModelConfig, TeamConfig, TeamResult, ToolConfig
|
||||
|
||||
# added for python3.11 and sqlmodel 0.0.22 incompatibility
|
||||
if hasattr(SQLModel, "model_config"):
|
||||
SQLModel.model_config["protected_namespaces"] = ()
|
||||
elif hasattr(SQLModel, "Config"):
|
||||
|
||||
class CustomSQLModel(SQLModel):
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
SQLModel = CustomSQLModel
|
||||
else:
|
||||
print("Warning: Unable to set protected_namespaces.")
|
||||
logger.warning("Unable to set protected_namespaces.")
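The fallback chain above exists because pydantic 2 reserves the "model_" namespace; a brief illustration of what clearing protected_namespaces avoids (hypothetical class, not part of the change):

# Without protected_namespaces = (), a field such as model_id on a SQLModel/pydantic
# class triggers a "conflict with protected namespace 'model_'" UserWarning.
class Example(SQLModel):
    model_id: int = 0  # warning-free once the namespace protection is cleared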
|
||||
|
||||
# pylint: disable=protected-access
|
||||
|
||||
|
@ -36,7 +40,7 @@ class ComponentTypes(Enum):
|
|||
ComponentTypes.TEAM: Team,
|
||||
ComponentTypes.AGENT: Agent,
|
||||
ComponentTypes.MODEL: Model,
|
||||
ComponentTypes.TOOL: Tool
|
||||
ComponentTypes.TOOL: Tool,
|
||||
}[self]
|
||||
|
||||
|
||||
|
@ -51,7 +55,7 @@ class LinkTypes(Enum):
|
|||
return {
|
||||
LinkTypes.AGENT_MODEL: (Agent, Model, AgentModelLink),
|
||||
LinkTypes.AGENT_TOOL: (Agent, Tool, AgentToolLink),
|
||||
LinkTypes.TEAM_AGENT: (Team, Agent, TeamAgentLink)
|
||||
LinkTypes.TEAM_AGENT: (Team, Agent, TeamAgentLink),
|
||||
}[self]
|
||||
|
||||
@property
|
||||
|
@ -70,40 +74,34 @@ class LinkTypes(Enum):
|
|||
# link models
|
||||
class AgentToolLink(SQLModel, table=True):
|
||||
__table_args__ = (
|
||||
UniqueConstraint('agent_id', 'sequence',
|
||||
name='unique_agent_tool_sequence'),
|
||||
{'sqlite_autoincrement': True}
|
||||
UniqueConstraint("agent_id", "sequence", name="unique_agent_tool_sequence"),
|
||||
{"sqlite_autoincrement": True},
|
||||
)
|
||||
agent_id: int = Field(default=None, primary_key=True,
|
||||
foreign_key="agent.id")
|
||||
agent_id: int = Field(default=None, primary_key=True, foreign_key="agent.id")
|
||||
tool_id: int = Field(default=None, primary_key=True, foreign_key="tool.id")
|
||||
sequence: Optional[int] = Field(default=0, primary_key=True)
|
||||
|
||||
|
||||
class AgentModelLink(SQLModel, table=True):
|
||||
__table_args__ = (
|
||||
UniqueConstraint('agent_id', 'sequence',
|
||||
name='unique_agent_tool_sequence'),
|
||||
{'sqlite_autoincrement': True}
|
||||
UniqueConstraint("agent_id", "sequence", name="unique_agent_tool_sequence"),
|
||||
{"sqlite_autoincrement": True},
|
||||
)
|
||||
agent_id: int = Field(default=None, primary_key=True,
|
||||
foreign_key="agent.id")
|
||||
model_id: int = Field(default=None, primary_key=True,
|
||||
foreign_key="model.id")
|
||||
agent_id: int = Field(default=None, primary_key=True, foreign_key="agent.id")
|
||||
model_id: int = Field(default=None, primary_key=True, foreign_key="model.id")
|
||||
sequence: Optional[int] = Field(default=0, primary_key=True)
|
||||
|
||||
|
||||
class TeamAgentLink(SQLModel, table=True):
|
||||
__table_args__ = (
|
||||
UniqueConstraint('agent_id', 'sequence',
|
||||
name='unique_agent_tool_sequence'),
|
||||
{'sqlite_autoincrement': True}
|
||||
UniqueConstraint("agent_id", "sequence", name="unique_agent_tool_sequence"),
|
||||
{"sqlite_autoincrement": True},
|
||||
)
|
||||
team_id: int = Field(default=None, primary_key=True, foreign_key="team.id")
|
||||
agent_id: int = Field(default=None, primary_key=True,
|
||||
foreign_key="agent.id")
|
||||
agent_id: int = Field(default=None, primary_key=True, foreign_key="agent.id")
|
||||
sequence: Optional[int] = Field(default=0, primary_key=True)
|
||||
|
||||
|
||||
# database models
|
||||
|
||||
|
||||
|
@ -120,10 +118,8 @@ class Tool(SQLModel, table=True):
|
|||
) # pylint: disable=not-callable
|
||||
user_id: Optional[str] = None
|
||||
version: Optional[str] = "0.0.1"
|
||||
config: Union[ToolConfig, dict] = Field(
|
||||
default_factory=ToolConfig, sa_column=Column(JSON))
|
||||
agents: List["Agent"] = Relationship(
|
||||
back_populates="tools", link_model=AgentToolLink)
|
||||
config: Union[ToolConfig, dict] = Field(default_factory=ToolConfig, sa_column=Column(JSON))
|
||||
agents: List["Agent"] = Relationship(back_populates="tools", link_model=AgentToolLink)
|
||||
|
||||
|
||||
class Model(SQLModel, table=True):
|
||||
|
@ -139,10 +135,8 @@ class Model(SQLModel, table=True):
|
|||
) # pylint: disable=not-callable
|
||||
user_id: Optional[str] = None
|
||||
version: Optional[str] = "0.0.1"
|
||||
config: Union[ModelConfig, dict] = Field(
|
||||
default_factory=ModelConfig, sa_column=Column(JSON))
|
||||
agents: List["Agent"] = Relationship(
|
||||
back_populates="models", link_model=AgentModelLink)
|
||||
config: Union[ModelConfig, dict] = Field(default_factory=ModelConfig, sa_column=Column(JSON))
|
||||
agents: List["Agent"] = Relationship(back_populates="models", link_model=AgentModelLink)
|
||||
|
||||
|
||||
class Team(SQLModel, table=True):
|
||||
|
@ -158,10 +152,8 @@ class Team(SQLModel, table=True):
|
|||
) # pylint: disable=not-callable
|
||||
user_id: Optional[str] = None
|
||||
version: Optional[str] = "0.0.1"
|
||||
config: Union[TeamConfig, dict] = Field(
|
||||
default_factory=TeamConfig, sa_column=Column(JSON))
|
||||
agents: List["Agent"] = Relationship(
|
||||
back_populates="teams", link_model=TeamAgentLink)
|
||||
config: Union[TeamConfig, dict] = Field(default_factory=TeamConfig, sa_column=Column(JSON))
|
||||
agents: List["Agent"] = Relationship(back_populates="teams", link_model=TeamAgentLink)
|
||||
|
||||
|
||||
class Agent(SQLModel, table=True):
|
||||
|
@ -177,14 +169,10 @@ class Agent(SQLModel, table=True):
|
|||
) # pylint: disable=not-callable
|
||||
user_id: Optional[str] = None
|
||||
version: Optional[str] = "0.0.1"
|
||||
config: Union[AgentConfig, dict] = Field(
|
||||
default_factory=AgentConfig, sa_column=Column(JSON))
|
||||
tools: List[Tool] = Relationship(
|
||||
back_populates="agents", link_model=AgentToolLink)
|
||||
models: List[Model] = Relationship(
|
||||
back_populates="agents", link_model=AgentModelLink)
|
||||
teams: List[Team] = Relationship(
|
||||
back_populates="agents", link_model=TeamAgentLink)
|
||||
config: Union[AgentConfig, dict] = Field(default_factory=AgentConfig, sa_column=Column(JSON))
|
||||
tools: List[Tool] = Relationship(back_populates="agents", link_model=AgentToolLink)
|
||||
models: List[Model] = Relationship(back_populates="agents", link_model=AgentModelLink)
|
||||
teams: List[Team] = Relationship(back_populates="agents", link_model=TeamAgentLink)
|
||||
|
||||
|
||||
class Message(SQLModel, table=True):
|
||||
|
@ -200,17 +188,12 @@ class Message(SQLModel, table=True):
|
|||
) # pylint: disable=not-callable
|
||||
user_id: Optional[str] = None
|
||||
version: Optional[str] = "0.0.1"
|
||||
config: Union[MessageConfig, dict] = Field(
|
||||
default_factory=MessageConfig, sa_column=Column(JSON))
|
||||
config: Union[MessageConfig, dict] = Field(default_factory=MessageConfig, sa_column=Column(JSON))
|
||||
session_id: Optional[int] = Field(
|
||||
default=None, sa_column=Column(Integer, ForeignKey("session.id", ondelete="CASCADE"))
|
||||
)
|
||||
run_id: Optional[UUID] = Field(
|
||||
default=None, foreign_key="run.id"
|
||||
)
|
||||
|
||||
message_meta: Optional[Union[MessageMeta, dict]] = Field(
|
||||
default={}, sa_column=Column(JSON))
|
||||
run_id: Optional[UUID] = Field(default=None, foreign_key="run.id")
|
||||
message_meta: Optional[Union[MessageMeta, dict]] = Field(default={}, sa_column=Column(JSON))
|
||||
|
||||
|
||||
class Session(SQLModel, table=True):
|
||||
|
@ -226,9 +209,7 @@ class Session(SQLModel, table=True):
|
|||
) # pylint: disable=not-callable
|
||||
user_id: Optional[str] = None
|
||||
version: Optional[str] = "0.0.1"
|
||||
team_id: Optional[int] = Field(
|
||||
default=None, sa_column=Column(Integer, ForeignKey("team.id", ondelete="CASCADE"))
|
||||
)
|
||||
team_id: Optional[int] = Field(default=None, sa_column=Column(Integer, ForeignKey("team.id", ondelete="CASCADE")))
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
|
@ -242,41 +223,59 @@ class RunStatus(str, Enum):
|
|||
|
||||
class Run(SQLModel, table=True):
|
||||
"""Represents a single execution run within a session"""
|
||||
|
||||
__table_args__ = {"sqlite_autoincrement": True}
|
||||
|
||||
# Primary key using UUID
|
||||
id: UUID = Field(
|
||||
default_factory=uuid4,
|
||||
primary_key=True,
|
||||
index=True
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True, index=True)
|
||||
created_at: datetime = Field(
|
||||
default_factory=datetime.now, sa_column=Column(DateTime(timezone=True), server_default=func.now())
|
||||
)
|
||||
updated_at: datetime = Field(
|
||||
default_factory=datetime.now, sa_column=Column(DateTime(timezone=True), onupdate=func.now())
|
||||
)
|
||||
session_id: Optional[int] = Field(
|
||||
default=None, sa_column=Column(Integer, ForeignKey("session.id", ondelete="CASCADE"), nullable=False)
|
||||
)
|
||||
status: RunStatus = Field(default=RunStatus.CREATED)
|
||||
|
||||
# Timestamps using the same pattern as other models
|
||||
# Store the original user task
|
||||
task: Union[MessageConfig, dict] = Field(default_factory=MessageConfig, sa_column=Column(JSON))
|
||||
|
||||
# Store TeamResult which contains TaskResult
|
||||
team_result: Union[TeamResult, dict] = Field(default=None, sa_column=Column(JSON))
|
||||
|
||||
error_message: Optional[str] = None
|
||||
version: Optional[str] = "0.0.1"
|
||||
messages: Union[List[Message], List[dict]] = Field(default_factory=list, sa_column=Column(JSON))
|
||||
|
||||
class Config:
|
||||
json_encoders = {UUID: str, datetime: lambda v: v.isoformat()}
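A minimal sketch of persisting a run with the new task and result fields, assuming an existing session with id 1 and a db_manager exposing the upsert method used elsewhere in this change:

run = Run(
    session_id=1,  # assumed existing session
    status=RunStatus.CREATED,
    task=MessageConfig(content="Plot NVDA vs TSLA stock prices", source="user").model_dump(),
)
db_manager.upsert(run)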
|
||||
|
||||
|
||||
class GalleryConfig(SQLModel, table=False):
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True, index=True)
|
||||
title: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
run: Run
|
||||
team: TeamConfig = None
|
||||
tags: Optional[List[str]] = None
|
||||
visibility: str = "public" # public, private, shared
|
||||
|
||||
class Config:
|
||||
json_encoders = {UUID: str, datetime: lambda v: v.isoformat()}
|
||||
|
||||
|
||||
class Gallery(SQLModel, table=True):
|
||||
__table_args__ = {"sqlite_autoincrement": True}
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
created_at: datetime = Field(
|
||||
default_factory=datetime.now,
|
||||
sa_column=Column(DateTime(timezone=True), server_default=func.now())
|
||||
sa_column=Column(DateTime(timezone=True), server_default=func.now()),
|
||||
)
|
||||
updated_at: datetime = Field(
|
||||
default_factory=datetime.now,
|
||||
sa_column=Column(DateTime(timezone=True), onupdate=func.now())
|
||||
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
|
||||
)
|
||||
|
||||
# Foreign key to Session
|
||||
session_id: Optional[int] = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
Integer,
|
||||
ForeignKey("session.id", ondelete="CASCADE"),
|
||||
nullable=False
|
||||
)
|
||||
)
|
||||
|
||||
# Run status and metadata
|
||||
status: RunStatus = Field(default=RunStatus.CREATED)
|
||||
error_message: Optional[str] = None
|
||||
|
||||
# Metadata storage following pattern from Message model
|
||||
run_meta: dict = Field(default={}, sa_column=Column(JSON))
|
||||
|
||||
# Version tracking like other models
|
||||
user_id: Optional[str] = None
|
||||
version: Optional[str] = "0.0.1"
|
||||
config: Union[GalleryConfig, dict] = Field(default_factory=GalleryConfig, sa_column=Column(JSON))
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from autogen_agentchat.base import TaskResult
|
||||
from pydantic import BaseModel
|
||||
from autogen_agentchat.base._task import TaskResult
|
||||
|
||||
|
||||
class ModelTypes(str, Enum):
|
||||
|
@ -29,9 +29,10 @@ class TerminationTypes(str, Enum):
|
|||
MAX_MESSAGES = "MaxMessageTermination"
|
||||
STOP_MESSAGE = "StopMessageTermination"
|
||||
TEXT_MENTION = "TextMentionTermination"
|
||||
COMBINATION = "CombinationTermination"
|
||||
|
||||
|
||||
class ComponentType(str, Enum):
|
||||
class ComponentTypes(str, Enum):
|
||||
TEAM = "team"
|
||||
AGENT = "agent"
|
||||
MODEL = "model"
|
||||
|
@ -40,11 +41,9 @@ class ComponentType(str, Enum):
|
|||
|
||||
|
||||
class BaseConfig(BaseModel):
|
||||
model_config = {
|
||||
"protected_namespaces": ()
|
||||
}
|
||||
model_config = {"protected_namespaces": ()}
|
||||
version: str = "1.0.0"
|
||||
component_type: ComponentType
|
||||
component_type: ComponentTypes
|
||||
|
||||
|
||||
class MessageConfig(BaseModel):
|
||||
|
@ -58,7 +57,7 @@ class ModelConfig(BaseConfig):
|
|||
model_type: ModelTypes
|
||||
api_key: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
component_type: ComponentType = ComponentType.MODEL
|
||||
component_type: ComponentTypes = ComponentTypes.MODEL
|
||||
|
||||
|
||||
class ToolConfig(BaseConfig):
|
||||
|
@ -66,7 +65,7 @@ class ToolConfig(BaseConfig):
|
|||
description: str
|
||||
content: str
|
||||
tool_type: ToolTypes
|
||||
component_type: ComponentType = ComponentType.TOOL
|
||||
component_type: ComponentTypes = ComponentTypes.TOOL
|
||||
|
||||
|
||||
class AgentConfig(BaseConfig):
|
||||
|
@ -76,14 +75,18 @@ class AgentConfig(BaseConfig):
|
|||
model_client: Optional[ModelConfig] = None
|
||||
tools: Optional[List[ToolConfig]] = None
|
||||
description: Optional[str] = None
|
||||
component_type: ComponentType = ComponentType.AGENT
|
||||
component_type: ComponentTypes = ComponentTypes.AGENT
|
||||
|
||||
|
||||
class TerminationConfig(BaseConfig):
|
||||
termination_type: TerminationTypes
|
||||
# Fields for basic terminations
|
||||
max_messages: Optional[int] = None
|
||||
text: Optional[str] = None
|
||||
component_type: ComponentType = ComponentType.TERMINATION
|
||||
# Fields for combinations
|
||||
operator: Optional[Literal["and", "or"]] = None
|
||||
conditions: Optional[List["TerminationConfig"]] = None
|
||||
component_type: ComponentTypes = ComponentTypes.TERMINATION
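A sketch of the new combination support: the operator and nested conditions fields let two basic conditions be OR-ed together (values here are illustrative):

termination = TerminationConfig(
    termination_type=TerminationTypes.COMBINATION,
    operator="or",
    conditions=[
        TerminationConfig(termination_type=TerminationTypes.MAX_MESSAGES, max_messages=10),
        TerminationConfig(termination_type=TerminationTypes.TEXT_MENTION, text="TERMINATE"),
    ],
)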
|
||||
|
||||
|
||||
class TeamConfig(BaseConfig):
|
||||
|
@ -93,7 +96,7 @@ class TeamConfig(BaseConfig):
|
|||
model_client: Optional[ModelConfig] = None
|
||||
selector_prompt: Optional[str] = None
|
||||
termination_condition: Optional[TerminationConfig] = None
|
||||
component_type: ComponentType = ComponentType.TEAM
|
||||
component_type: ComponentTypes = ComponentTypes.TEAM
|
||||
|
||||
|
||||
class TeamResult(BaseModel):
|
||||
|
@ -111,6 +114,7 @@ class MessageMeta(BaseModel):
|
|||
log: Optional[List[dict]] = None
|
||||
usage: Optional[List[dict]] = None
|
||||
|
||||
|
||||
# web request/response data models
|
||||
|
||||
|
||||
|
@ -126,12 +130,6 @@ class SocketMessage(BaseModel):
|
|||
type: str
|
||||
|
||||
|
||||
ComponentConfig = Union[
|
||||
TeamConfig,
|
||||
AgentConfig,
|
||||
ModelConfig,
|
||||
ToolConfig,
|
||||
TerminationConfig
|
||||
]
|
||||
ComponentConfig = Union[TeamConfig, AgentConfig, ModelConfig, ToolConfig, TerminationConfig]
|
||||
|
||||
ComponentConfigInput = Union[str, Path, dict, ComponentConfig]
|
||||
|
|
|
@ -1,108 +0,0 @@
|
|||
# metrics - agent_frequency, execution_count, tool_count,
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .datamodel import Message, MessageMeta
|
||||
|
||||
|
||||
class Profiler:
|
||||
"""
|
||||
Profiler class to profile agent task runs and compute metrics
|
||||
for performance evaluation.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.metrics: List[Dict] = []
|
||||
|
||||
def _is_code(self, message: Message) -> bool:
|
||||
"""
|
||||
Check if the message contains code.
|
||||
|
||||
:param message: The message instance to check.
|
||||
:return: True if the message contains code, False otherwise.
|
||||
"""
|
||||
content = message.get("message").get("content").lower()
|
||||
return "```" in content
|
||||
|
||||
def _is_tool(self, message: Message) -> bool:
|
||||
"""
|
||||
Check if the message uses a tool.
|
||||
|
||||
:param message: The message instance to check.
|
||||
:return: True if the message uses a tool, False otherwise.
|
||||
"""
|
||||
content = message.get("message").get("content").lower()
|
||||
return "from skills import" in content
|
||||
|
||||
def _is_code_execution(self, message: Message) -> bool:
|
||||
"""
|
||||
Check if the message indicates code execution.
|
||||
|
||||
:param message: The message instance to check.
|
||||
:return: dict with is_code and status keys.
|
||||
"""
|
||||
content = message.get("message").get("content").lower()
|
||||
if "exitcode:" in content:
|
||||
status = "exitcode: 0" in content
|
||||
return {"is_code": True, "status": status}
|
||||
else:
|
||||
return {"is_code": False, "status": False}
|
||||
|
||||
def _is_terminate(self, message: Message) -> bool:
|
||||
"""
|
||||
Check if the message indicates termination.
|
||||
|
||||
:param message: The message instance to check.
|
||||
:return: True if the message indicates termination, False otherwise.
|
||||
"""
|
||||
content = message.get("message").get("content").lower()
|
||||
return "terminate" in content
|
||||
|
||||
def profile(self, agent_message: Message):
|
||||
"""
|
||||
Profile the agent task run and compute metrics.
|
||||
|
||||
:param agent: The agent instance that ran the task.
|
||||
:param task: The task instance that was run.
|
||||
"""
|
||||
meta = MessageMeta(**agent_message.meta)
|
||||
print(meta.log)
|
||||
usage = meta.usage
|
||||
messages = meta.messages
|
||||
profile = []
|
||||
bar = []
|
||||
stats = {}
|
||||
total_code_executed = 0
|
||||
success_code_executed = 0
|
||||
agents = []
|
||||
for message in messages:
|
||||
agent = message.get("sender")
|
||||
is_code = self._is_code(message)
|
||||
is_tool = self._is_tool(message)
|
||||
is_code_execution = self._is_code_execution(message)
|
||||
total_code_executed += is_code_execution["is_code"]
|
||||
success_code_executed += 1 if is_code_execution["status"] else 0
|
||||
|
||||
row = {
|
||||
"agent": agent,
|
||||
"tool_call": is_code,
|
||||
"code_execution": is_code_execution,
|
||||
"terminate": self._is_terminate(message),
|
||||
}
|
||||
bar_row = {
|
||||
"agent": agent,
|
||||
"tool_call": "tool call" if is_tool else "no tool call",
|
||||
"code_execution": (
|
||||
"success"
|
||||
if is_code_execution["status"]
|
||||
else "failure" if is_code_execution["is_code"] else "no code"
|
||||
),
|
||||
"message": 1,
|
||||
}
|
||||
profile.append(row)
|
||||
bar.append(bar_row)
|
||||
agents.append(agent)
|
||||
code_success_rate = (success_code_executed / total_code_executed if total_code_executed > 0 else 0) * 100
|
||||
stats["code_success_rate"] = code_success_rate
|
||||
stats["total_code_executed"] = total_code_executed
|
||||
return {"profile": profile, "bar": bar, "stats": stats, "agents": set(agents), "usage": usage}
|
|
@ -1,50 +1,39 @@
|
|||
from typing import AsyncGenerator, Callable, Union, Optional
|
||||
import time
|
||||
from .database import ComponentFactory, Component
|
||||
from .datamodel import TeamResult, TaskResult, ComponentConfigInput
|
||||
from autogen_agentchat.messages import ChatMessage, AgentMessage
|
||||
from typing import AsyncGenerator, Callable, Optional, Union
|
||||
|
||||
from autogen_agentchat.base import TaskResult
|
||||
from autogen_agentchat.messages import AgentMessage, ChatMessage
|
||||
from autogen_core.base import CancellationToken
|
||||
|
||||
from .database import Component, ComponentFactory
|
||||
from .datamodel import ComponentConfigInput, TeamResult
|
||||
|
||||
|
||||
class TeamManager:
|
||||
def __init__(self) -> None:
|
||||
self.component_factory = ComponentFactory()
|
||||
|
||||
async def _create_team(
|
||||
self,
|
||||
team_config: ComponentConfigInput,
|
||||
input_func: Optional[Callable] = None
|
||||
) -> Component:
|
||||
async def _create_team(self, team_config: ComponentConfigInput, input_func: Optional[Callable] = None) -> Component:
|
||||
"""Create team instance with common setup logic"""
|
||||
return await self.component_factory.load(
|
||||
team_config,
|
||||
input_func=input_func
|
||||
)
|
||||
return await self.component_factory.load(team_config, input_func=input_func)
|
||||
|
||||
def _create_result(self, task_result: TaskResult, start_time: float) -> TeamResult:
|
||||
"""Create TeamResult with timing info"""
|
||||
return TeamResult(
|
||||
task_result=task_result,
|
||||
usage="",
|
||||
duration=time.time() - start_time
|
||||
)
|
||||
return TeamResult(task_result=task_result, usage="", duration=time.time() - start_time)
|
||||
|
||||
async def run_stream(
|
||||
self,
|
||||
task: str,
|
||||
team_config: ComponentConfigInput,
|
||||
input_func: Optional[Callable] = None,
|
||||
cancellation_token: Optional[CancellationToken] = None
|
||||
cancellation_token: Optional[CancellationToken] = None,
|
||||
) -> AsyncGenerator[Union[AgentMessage, ChatMessage, TaskResult], None]:
|
||||
"""Stream the team's execution results"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
team = await self._create_team(team_config, input_func)
|
||||
stream = team.run_stream(
|
||||
task=task,
|
||||
cancellation_token=cancellation_token
|
||||
)
|
||||
stream = team.run_stream(task=task, cancellation_token=cancellation_token)
|
||||
|
||||
async for message in stream:
|
||||
if cancellation_token and cancellation_token.is_cancelled():
|
||||
|
@ -63,15 +52,12 @@ class TeamManager:
|
|||
task: str,
|
||||
team_config: ComponentConfigInput,
|
||||
input_func: Optional[Callable] = None,
|
||||
cancellation_token: Optional[CancellationToken] = None
|
||||
cancellation_token: Optional[CancellationToken] = None,
|
||||
) -> TeamResult:
|
||||
"""Original non-streaming run method with optional cancellation"""
|
||||
start_time = time.time()
|
||||
|
||||
team = await self._create_team(team_config, input_func)
|
||||
result = await team.run(
|
||||
task=task,
|
||||
cancellation_token=cancellation_token
|
||||
)
|
||||
result = await team.run(task=task, cancellation_token=cancellation_token)
|
||||
|
||||
return self._create_result(result, start_time)
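A minimal sketch of driving TeamManager directly; the import path, team.json path, and task text are placeholders, and team_config may equally be a dict or a typed TeamConfig:

import asyncio

from autogenstudio.teammanager import TeamManager  # assumed import path

async def main() -> None:
    manager = TeamManager()
    result = await manager.run(task="What is the capital of France?", team_config="team.json")
    print(result.duration, result.task_result)

asyncio.run(main())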
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
from .utils import *
|
|
@ -10,7 +10,6 @@ from typing import Any, Dict, List, Tuple, Union
|
|||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
|
||||
from ..datamodel import Model
|
||||
from ..version import APP_NAME
|
||||
|
||||
|
@ -153,8 +152,7 @@ def get_modified_files(start_timestamp: float, end_timestamp: float, source_dir:
|
|||
for root, dirs, files in os.walk(source_dir):
|
||||
# Update directories and files to exclude those to be ignored
|
||||
dirs[:] = [d for d in dirs if d not in ignore_files]
|
||||
files[:] = [f for f in files if f not in ignore_files and os.path.splitext(f)[
|
||||
1] not in ignore_extensions]
|
||||
files[:] = [f for f in files if f not in ignore_files and os.path.splitext(f)[1] not in ignore_extensions]
|
||||
|
||||
for file in files:
|
||||
file_path = os.path.join(root, file)
|
||||
|
@ -163,9 +161,7 @@ def get_modified_files(start_timestamp: float, end_timestamp: float, source_dir:
|
|||
# Verify if the file was modified within the given timestamp range
|
||||
if start_timestamp <= file_mtime <= end_timestamp:
|
||||
file_relative_path = (
|
||||
"files/user" +
|
||||
file_path.split(
|
||||
"files/user", 1)[1] if "files/user" in file_path else ""
|
||||
"files/user" + file_path.split("files/user", 1)[1] if "files/user" in file_path else ""
|
||||
)
|
||||
file_type = get_file_type(file_path)
|
||||
|
||||
|
@ -253,41 +249,27 @@ def sanitize_model(model: Model):
|
|||
model = model.model_dump()
|
||||
valid_keys = ["model", "base_url", "api_key", "api_type", "api_version"]
|
||||
# only add key if value is not None
|
||||
sanitized_model = {k: v for k, v in model.items() if (
|
||||
v is not None and v != "") and k in valid_keys}
|
||||
sanitized_model = {k: v for k, v in model.items() if (v is not None and v != "") and k in valid_keys}
|
||||
return sanitized_model
|
||||
|
||||
|
||||
def test_model(model: Model):
|
||||
"""
|
||||
Test the model endpoint by sending a simple message to the model and returning the response.
|
||||
"""
|
||||
class Version:
|
||||
def __init__(self, ver_str: str):
|
||||
try:
|
||||
# Split into major.minor.patch
|
||||
self.major, self.minor, self.patch = map(int, ver_str.split("."))
|
||||
except (ValueError, AttributeError) as err:
|
||||
raise ValueError(f"Invalid version format: {ver_str}. Expected: major.minor.patch") from err
|
||||
|
||||
print("Testing model", model)
|
||||
def __str__(self):
|
||||
return f"{self.major}.{self.minor}.{self.patch}"
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, str):
|
||||
other = Version(other)
|
||||
return (self.major, self.minor, self.patch) == (other.major, other.minor, other.patch)
|
||||
|
||||
# def summarize_chat_history(task: str, messages: List[Dict[str, str]], client: ModelClient):
|
||||
# """
|
||||
# Summarize the chat history using the model endpoint and returning the response.
|
||||
# """
|
||||
# summarization_system_prompt = f"""
|
||||
# You are a helpful assistant that is able to review the chat history between a set of agents (userproxy agents, assistants etc) as they try to address a given TASK and provide a summary. Be SUCCINCT but also comprehensive enough to allow others (who cannot see the chat history) understand and recreate the solution.
|
||||
|
||||
# The task requested by the user is:
|
||||
# ===
|
||||
# {task}
|
||||
# ===
|
||||
# The summary should focus on extracting the actual solution to the task from the chat history (assuming the task was addressed) such that any other agent reading the summary will understand what the actual solution is. Use a neutral tone and DO NOT directly mention the agents. Instead only focus on the actions that were carried out (e.g. do not say 'assistant agent generated some code visualization code ..' instead say say 'visualization code was generated ..'. The answer should be framed as a response to the user task. E.g. if the task is "What is the height of the Eiffel tower", the summary should be "The height of the Eiffel Tower is ...").
|
||||
# """
|
||||
# summarization_prompt = [
|
||||
# {
|
||||
# "role": "system",
|
||||
# "content": summarization_system_prompt,
|
||||
# },
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": f"Summarize the following chat history. {str(messages)}",
|
||||
# },
|
||||
# ]
|
||||
# response = client.create(messages=summarization_prompt, cache_seed=None)
|
||||
# return response.choices[0].message.content
|
||||
def __gt__(self, other):
|
||||
if isinstance(other, str):
|
||||
other = Version(other)
|
||||
return (self.major, self.minor, self.patch) > (other.major, other.minor, other.patch)
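The comparison helpers accept either another Version or a plain string, so version gates can be written directly against string literals:

assert Version("0.4.1") > "0.4.0"
assert Version("0.4.0") == Version("0.4.0")
# Malformed input fails fast:
# Version("0.4") -> ValueError: Invalid version format: 0.4. Expected: major.minor.patch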
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
VERSION = "0.4.0.dev37"
|
||||
VERSION = "0.4.0.dev38"
|
||||
__version__ = VERSION
|
||||
APP_NAME = "autogenstudio"
|
||||
|
|
|
@ -1,18 +1,19 @@
|
|||
# api/app.py
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
# import logging
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
from loguru import logger
|
||||
|
||||
from .routes import sessions, runs, teams, agents, models, tools, ws
|
||||
from .deps import init_managers, cleanup_managers
|
||||
from .config import settings
|
||||
from .initialization import AppInitializer
|
||||
from ..version import VERSION
|
||||
from .config import settings
|
||||
from .deps import cleanup_managers, init_managers
|
||||
from .initialization import AppInitializer
|
||||
from .routes import agents, models, runs, sessions, teams, tools, ws
|
||||
|
||||
# Configure logging
|
||||
# logger = logging.getLogger(__name__)
|
||||
|
@ -54,6 +55,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||
except Exception as e:
|
||||
logger.error(f"Error during shutdown: {str(e)}")
|
||||
|
||||
|
||||
# Create FastAPI application
|
||||
app = FastAPI(lifespan=lifespan, debug=True)
|
||||
|
||||
|
@ -143,6 +145,7 @@ async def get_version():
|
|||
"data": {"version": VERSION},
|
||||
}
|
||||
|
||||
|
||||
# Health check endpoint
|
||||
|
||||
|
||||
|
@ -154,6 +157,7 @@ async def health_check():
|
|||
"message": "Service is healthy",
|
||||
}
|
||||
|
||||
|
||||
# Mount static file directories
|
||||
app.mount("/api", api)
|
||||
app.mount(
|
||||
|
@ -172,7 +176,7 @@ async def internal_error_handler(request, exc):
|
|||
return {
|
||||
"status": False,
|
||||
"message": "Internal server error",
|
||||
"detail": str(exc) if settings.API_DOCS else "Internal server error"
|
||||
"detail": str(exc) if settings.API_DOCS else "Internal server error",
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
# api/deps.py
|
||||
from typing import Optional
|
||||
from fastapi import Depends, HTTPException, status
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional
|
||||
|
||||
from ..database import DatabaseManager
|
||||
from .managers.connection import WebSocketManager
|
||||
from fastapi import Depends, HTTPException, status
|
||||
|
||||
from ..database import ConfigurationManager, DatabaseManager
|
||||
from ..teammanager import TeamManager
|
||||
from .config import settings
|
||||
from ..database import ConfigurationManager
|
||||
from .managers.connection import WebSocketManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -25,17 +25,16 @@ def get_db_context():
|
|||
"""Provide a transactional scope around a series of operations."""
|
||||
if not _db_manager:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Database manager not initialized"
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database manager not initialized"
|
||||
)
|
||||
try:
|
||||
yield _db_manager
|
||||
except Exception as e:
|
||||
logger.error(f"Database operation failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Database operation failed"
|
||||
)
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database operation failed"
|
||||
) from e
|
||||
|
||||
|
||||
# Dependency providers
|
||||
|
||||
|
@ -44,8 +43,7 @@ async def get_db() -> DatabaseManager:
|
|||
"""Dependency provider for database manager"""
|
||||
if not _db_manager:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Database manager not initialized"
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database manager not initialized"
|
||||
)
|
||||
return _db_manager
|
||||
|
||||
|
@ -54,8 +52,7 @@ async def get_websocket_manager() -> WebSocketManager:
|
|||
"""Dependency provider for connection manager"""
|
||||
if not _websocket_manager:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Connection manager not initialized"
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Connection manager not initialized"
|
||||
)
|
||||
return _websocket_manager
|
||||
|
||||
|
@ -63,12 +60,10 @@ async def get_websocket_manager() -> WebSocketManager:
|
|||
async def get_team_manager() -> TeamManager:
|
||||
"""Dependency provider for team manager"""
|
||||
if not _team_manager:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Team manager not initialized"
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Team manager not initialized")
|
||||
return _team_manager
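A sketch of how these providers are meant to be consumed from a route module; the router path and the relative import location are hypothetical:

from fastapi import APIRouter, Depends

from ..deps import get_team_manager  # assumed location relative to a routes module

router = APIRouter()

@router.get("/managers/status")
async def managers_status(team_manager=Depends(get_team_manager)):
    # The dependency itself raises a 500 if the manager was never initialized.
    return {"team_manager_ready": team_manager is not None}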
|
||||
|
||||
|
||||
# Authentication dependency
|
||||
|
||||
|
||||
|
@ -83,6 +78,7 @@ async def get_current_user(
|
|||
# Implement your user authentication here
|
||||
return "user_id" # Replace with actual user identification
|
||||
|
||||
|
||||
# Manager initialization and cleanup
|
||||
|
||||
|
||||
|
@ -94,20 +90,16 @@ async def init_managers(database_uri: str, config_dir: str, app_root: str) -> No
|
|||
|
||||
try:
|
||||
# Initialize database manager
|
||||
_db_manager = DatabaseManager(
|
||||
engine_uri=database_uri, auto_upgrade=settings.UPGRADE_DATABASE, base_dir=app_root)
|
||||
_db_manager.create_db_and_tables()
|
||||
_db_manager = DatabaseManager(engine_uri=database_uri, base_dir=app_root)
|
||||
_db_manager.initialize_database(auto_upgrade=settings.UPGRADE_DATABASE)
|
||||
|
||||
# init default team config
|
||||
|
||||
_team_config_manager = ConfigurationManager(db_manager=_db_manager)
|
||||
import_result = await _team_config_manager.import_directory(
|
||||
config_dir, settings.DEFAULT_USER_ID, check_exists=True)
|
||||
await _team_config_manager.import_directory(config_dir, settings.DEFAULT_USER_ID, check_exists=True)
|
||||
|
||||
# Initialize connection manager
|
||||
_websocket_manager = WebSocketManager(
|
||||
db_manager=_db_manager
|
||||
)
|
||||
_websocket_manager = WebSocketManager(db_manager=_db_manager)
|
||||
logger.info("Connection manager initialized")
|
||||
|
||||
# Initialize team manager
|
||||
|
@ -149,6 +141,7 @@ async def cleanup_managers() -> None:
|
|||
|
||||
logger.info("All managers cleaned up")
|
||||
|
||||
|
||||
# Utility functions for dependency management
|
||||
|
||||
|
||||
|
@ -157,19 +150,17 @@ def get_manager_status() -> dict:
|
|||
return {
|
||||
"database_manager": _db_manager is not None,
|
||||
"websocket_manager": _websocket_manager is not None,
|
||||
"team_manager": _team_manager is not None
|
||||
"team_manager": _team_manager is not None,
|
||||
}
|
||||
|
||||
|
||||
# Combined dependencies
|
||||
|
||||
|
||||
async def get_managers():
|
||||
"""Get all managers in one dependency"""
|
||||
return {
|
||||
"db": await get_db(),
|
||||
"connection": await get_websocket_manager(),
|
||||
"team": await get_team_manager()
|
||||
}
|
||||
return {"db": await get_db(), "connection": await get_websocket_manager(), "team": await get_team_manager()}
|
||||
|
||||
|
||||
# Error handling for manager operations
|
||||
|
||||
|
@ -183,19 +174,21 @@ class ManagerOperationError(Exception):
|
|||
self.detail = detail
|
||||
super().__init__(f"{manager_name} failed during {operation}: {detail}")
|
||||
|
||||
|
||||
# Dependency for requiring specific managers
|
||||
|
||||
|
||||
def require_managers(*manager_names: str):
|
||||
"""Decorator to require specific managers for a route"""
|
||||
|
||||
async def dependency():
|
||||
status = get_manager_status()
|
||||
missing = [name for name in manager_names if not status.get(
|
||||
f"{name}_manager")]
|
||||
missing = [name for name in manager_names if not status.get(f"{name}_manager")]
|
||||
if missing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Required managers not available: {', '.join(missing)}"
|
||||
detail=f"Required managers not available: {', '.join(missing)}",
|
||||
)
|
||||
return True
|
||||
|
||||
return Depends(dependency)
|
||||
|
|
|
@ -2,15 +2,17 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from pydantic import BaseModel
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import Settings
|
||||
|
||||
|
||||
class _AppPaths(BaseModel):
|
||||
"""Internal model representing all application paths"""
|
||||
|
||||
app_root: Path
|
||||
static_root: Path
|
||||
user_files: Path
|
||||
|
@ -47,9 +49,7 @@ class AppInitializer:
|
|||
"""Generate database URI based on settings or environment"""
|
||||
if db_uri := os.getenv("AUTOGENSTUDIO_DATABASE_URI"):
|
||||
return db_uri
|
||||
return self.settings.DATABASE_URI.replace(
|
||||
"./", str(app_root) + "/"
|
||||
)
|
||||
return self.settings.DATABASE_URI.replace("./", str(app_root) + "/")
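The resolution order is therefore: an AUTOGENSTUDIO_DATABASE_URI environment variable wins, otherwise the settings default is rebased onto the application root. A sketch with an illustrative value:

import os

# Hypothetical override set before the app starts; any SQLAlchemy-style URI should work.
os.environ["AUTOGENSTUDIO_DATABASE_URI"] = "sqlite:///./custom_autogenstudio.db"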
|
||||
|
||||
def _init_paths(self) -> _AppPaths:
|
||||
"""Initialize and return AppPaths instance"""
|
||||
|
@ -60,14 +60,13 @@ class AppInitializer:
|
|||
user_files=app_root / "files" / "user",
|
||||
ui_root=self._app_path / "ui",
|
||||
config_dir=app_root / self.settings.CONFIG_DIR,
|
||||
database_uri=self._get_database_uri(app_root)
|
||||
database_uri=self._get_database_uri(app_root),
|
||||
)
|
||||
|
||||
def _create_directories(self) -> None:
|
||||
"""Create all required directories"""
|
||||
self.app_root.mkdir(parents=True, exist_ok=True)
|
||||
dirs = [self.static_root, self.user_files,
|
||||
self.ui_root, self.config_dir]
|
||||
dirs = [self.static_root, self.user_files, self.ui_root, self.config_dir]
|
||||
for path in dirs:
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
|
|
@ -1,16 +1,17 @@
|
|||
import asyncio
|
||||
from autogen_agentchat.base._task import TaskResult
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from typing import Callable, Dict, Optional, Any
|
||||
from uuid import UUID
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from ...datamodel import Run, RunStatus, TeamResult
|
||||
from ...database import DatabaseManager
|
||||
from ...teammanager import TeamManager
|
||||
from autogen_agentchat.base._task import TaskResult
|
||||
from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage
|
||||
from autogen_core.base import CancellationToken
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
|
||||
from ...database import DatabaseManager
|
||||
from ...datamodel import Message, MessageConfig, Run, RunStatus, TeamResult
|
||||
from ...teammanager import TeamManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -26,12 +27,20 @@ class WebSocketManager:
|
|||
self._closed_connections: set[UUID] = set()
|
||||
self._input_responses: Dict[UUID, asyncio.Queue] = {}
|
||||
|
||||
self._cancel_message = TeamResult(task_result=TaskResult(messages=[TextMessage(
|
||||
source="user", content="Run cancelled by user")], stop_reason="cancelled by user"), usage="", duration=0).model_dump()
|
||||
self._cancel_message = TeamResult(
|
||||
task_result=TaskResult(
|
||||
messages=[TextMessage(source="user", content="Run cancelled by user")], stop_reason="cancelled by user"
|
||||
),
|
||||
usage="",
|
||||
duration=0,
|
||||
).model_dump()
|
||||
|
||||
def _get_stop_message(self, reason: str) -> dict:
|
||||
return TeamResult(task_result=TaskResult(messages=[TextMessage(
|
||||
source="user", content=reason)], stop_reason=reason), usage="", duration=0).model_dump()
|
||||
return TeamResult(
|
||||
task_result=TaskResult(messages=[TextMessage(source="user", content=reason)], stop_reason=reason),
|
||||
usage="",
|
||||
duration=0,
|
||||
).model_dump()
|
||||
|
||||
async def connect(self, websocket: WebSocket, run_id: UUID) -> bool:
|
||||
try:
|
||||
|
@ -41,87 +50,118 @@ class WebSocketManager:
|
|||
# Initialize input queue for this connection
|
||||
self._input_responses[run_id] = asyncio.Queue()
|
||||
|
||||
run = await self._get_run(run_id)
|
||||
if run:
|
||||
run.status = RunStatus.ACTIVE
|
||||
self.db_manager.upsert(run)
|
||||
|
||||
await self._send_message(run_id, {
|
||||
"type": "system",
|
||||
"status": "connected",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
await self._send_message(
|
||||
run_id, {"type": "system", "status": "connected", "timestamp": datetime.now(timezone.utc).isoformat()}
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Connection error for run {run_id}: {e}")
|
||||
return False
|
||||
|
||||
async def start_stream(
|
||||
self,
|
||||
run_id: UUID,
|
||||
team_manager: TeamManager,
|
||||
task: str,
|
||||
team_config: dict
|
||||
) -> None:
|
||||
async def start_stream(self, run_id: UUID, task: str, team_config: dict) -> None:
|
||||
"""Start streaming task execution with proper run management"""
|
||||
if run_id not in self._connections or run_id in self._closed_connections:
|
||||
raise ValueError(f"No active connection for run {run_id}")
|
||||
|
||||
team_manager = TeamManager()
|
||||
cancellation_token = CancellationToken()
|
||||
self._cancellation_tokens[run_id] = cancellation_token
|
||||
final_result = None
|
||||
|
||||
try:
|
||||
# Create input function for this run
|
||||
# Update run with task and status
|
||||
run = await self._get_run(run_id)
|
||||
if run:
|
||||
run.task = MessageConfig(content=task, source="user").model_dump()
|
||||
run.status = RunStatus.ACTIVE
|
||||
self.db_manager.upsert(run)
|
||||
|
||||
input_func = self.create_input_func(run_id)
|
||||
|
||||
async for message in team_manager.run_stream(
|
||||
task=task,
|
||||
team_config=team_config,
|
||||
input_func=input_func, # Pass the input function
|
||||
cancellation_token=cancellation_token
|
||||
task=task, team_config=team_config, input_func=input_func, cancellation_token=cancellation_token
|
||||
):
|
||||
if cancellation_token.is_cancelled() or run_id in self._closed_connections:
|
||||
logger.info(
|
||||
f"Stream cancelled or connection closed for run {run_id}")
|
||||
logger.info(f"Stream cancelled or connection closed for run {run_id}")
|
||||
break
|
||||
|
||||
formatted_message = self._format_message(message)
|
||||
if formatted_message:
|
||||
await self._send_message(run_id, formatted_message)
|
||||
|
||||
# Save message if it's a content message
|
||||
if isinstance(message, (AgentMessage, ChatMessage)):
|
||||
await self._save_message(run_id, message)
|
||||
# Capture final result if it's a TeamResult
|
||||
elif isinstance(message, TeamResult):
|
||||
final_result = message.model_dump()
|
||||
|
||||
if not cancellation_token.is_cancelled() and run_id not in self._closed_connections:
|
||||
await self._update_run_status(run_id, RunStatus.COMPLETE)
|
||||
if final_result:
|
||||
await self._update_run(run_id, RunStatus.COMPLETE, team_result=final_result)
|
||||
else:
|
||||
logger.warning(f"No final result captured for completed run {run_id}")
|
||||
await self._update_run_status(run_id, RunStatus.COMPLETE)
|
||||
else:
|
||||
await self._send_message(run_id, {
|
||||
"type": "completion",
|
||||
"status": "cancelled",
|
||||
"data": self._cancel_message,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
await self._update_run_status(run_id, RunStatus.STOPPED)
|
||||
await self._send_message(
|
||||
run_id,
|
||||
{
|
||||
"type": "completion",
|
||||
"status": "cancelled",
|
||||
"data": self._cancel_message,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
# Update run with cancellation result
|
||||
await self._update_run(run_id, RunStatus.STOPPED, team_result=self._cancel_message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Stream error for run {run_id}: {e}")
|
||||
await self._handle_stream_error(run_id, e)
|
||||
|
||||
finally:
|
||||
self._cancellation_tokens.pop(run_id, None)
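A sketch of the intended call sequence from a websocket route (run_id and team_config_dict are placeholders); note that start_stream now constructs its own TeamManager internally instead of receiving one:

# Hypothetical route body; ws_manager is the shared WebSocketManager instance.
if await ws_manager.connect(websocket, run_id):
    await ws_manager.start_stream(run_id, task="Summarize the latest run", team_config=team_config_dict)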
|
||||
|
||||
async def _save_message(self, run_id: UUID, message: Union[AgentMessage, ChatMessage]) -> None:
|
||||
"""Save a message to the database"""
|
||||
run = await self._get_run(run_id)
|
||||
if run:
|
||||
db_message = Message(
|
||||
session_id=run.session_id,
|
||||
run_id=run_id,
|
||||
config=message.model_dump(),
|
||||
user_id=None, # You might want to pass this from somewhere
|
||||
)
|
||||
self.db_manager.upsert(db_message)
|
||||
|
||||
async def _update_run(
|
||||
self, run_id: UUID, status: RunStatus, team_result: Optional[dict] = None, error: Optional[str] = None
|
||||
) -> None:
|
||||
"""Update run status and result"""
|
||||
run = await self._get_run(run_id)
|
||||
if run:
|
||||
run.status = status
|
||||
if team_result:
|
||||
run.team_result = team_result
|
||||
if error:
|
||||
run.error_message = error
|
||||
self.db_manager.upsert(run)
|
||||
|
||||
def create_input_func(self, run_id: UUID) -> Callable:
|
||||
"""Creates an input function for a specific run"""
|
||||
async def input_handler(prompt: str = "") -> str:
|
||||
try:
|
||||
|
||||
async def input_handler(prompt: str = "", cancellation_token: Optional[CancellationToken] = None) -> str:
|
||||
try:
|
||||
# Send input request to client
|
||||
await self._send_message(run_id, {
|
||||
"type": "input_request",
|
||||
"prompt": prompt,
|
||||
"data": {
|
||||
"source": "system",
|
||||
"content": prompt
|
||||
await self._send_message(
|
||||
run_id,
|
||||
{
|
||||
"type": "input_request",
|
||||
"prompt": prompt,
|
||||
"data": {"source": "system", "content": prompt},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
)
|
||||
|
||||
# Wait for response
|
||||
if run_id in self._input_responses:
|
||||
|
@ -141,26 +181,37 @@ class WebSocketManager:
|
|||
if run_id in self._input_responses:
|
||||
await self._input_responses[run_id].put(response)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Received input response for inactive run {run_id}")
|
||||
logger.warning(f"Received input response for inactive run {run_id}")
|
||||
|
||||
    async def stop_run(self, run_id: UUID, reason: str) -> None:
        """Stop a running task"""
        if run_id in self._cancellation_tokens:
            logger.info(f"Stopping run {run_id}")
            # self._cancellation_tokens[run_id].cancel()

            stop_message = self._get_stop_message(reason)

            try:
                # Update run record first
                await self._update_run(run_id, status=RunStatus.STOPPED, team_result=stop_message)

                # Then handle websocket communication if connection is active
                if run_id in self._connections and run_id not in self._closed_connections:
                    await self._send_message(
                        run_id,
                        {
                            "type": "completion",
                            "status": "cancelled",
                            "data": stop_message,
                            "timestamp": datetime.now(timezone.utc).isoformat(),
                        },
                    )

                # Finally cancel the token
                self._cancellation_tokens[run_id].cancel()

            except Exception as e:
                logger.error(f"Error stopping run {run_id}: {e}")
                # We might want to force disconnect here if db update failed
                # await self.disconnect(run_id)  # Optional
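For a client listening on the run socket, a stop resolves to a single completion frame. The shape below is a hedged illustration assembled from the fields used above; _get_stop_message is not shown in this diff, so the data payload is assumed to be a serialized TeamResult carrying the stop reason.

```python
# Illustrative only: approximate frame a client receives after stop_run fires.
cancelled_frame = {
    "type": "completion",
    "status": "cancelled",
    "data": {  # assumed TeamResult.model_dump() shape
        "task_result": {"messages": [], "stop_reason": "User requested stop/cancellation"},
        "usage": "",
        "duration": 0,
    },
    "timestamp": "2024-11-25T12:00:00+00:00",
}
print(cancelled_frame["status"])
```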
async def disconnect(self, run_id: UUID) -> None:
|
||||
"""Clean up connection and associated resources"""
|
||||
|
@ -185,8 +236,7 @@ class WebSocketManager:
|
|||
            message: Message dictionary to send
        """
        if run_id in self._closed_connections:
            logger.warning(f"Attempted to send message to closed connection for run {run_id}")
            return

        try:
@ -194,36 +244,36 @@ class WebSocketManager:
|
|||
            websocket = self._connections[run_id]
            await websocket.send_json(message)
        except WebSocketDisconnect:
            logger.warning(f"WebSocket disconnected while sending message for run {run_id}")
            await self.disconnect(run_id)
        except Exception as e:
            logger.error(f"Error sending message for run {run_id}: {e}, {message}")
            # Don't try to send error message here to avoid potential recursive loop
            await self._update_run_status(run_id, RunStatus.ERROR, str(e))
            await self.disconnect(run_id)
    async def _handle_stream_error(self, run_id: UUID, error: Exception) -> None:
        """Handle stream errors with proper run updates"""
        if run_id not in self._closed_connections:
            error_result = TeamResult(
                task_result=TaskResult(
                    messages=[TextMessage(source="system", content=str(error))], stop_reason="error"
                ),
                usage="",
                duration=0,
            ).model_dump()

            await self._send_message(
                run_id,
                {
                    "type": "completion",
                    "status": "error",
                    "error": str(error),
                    "data": error_result,
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                },
            )

            await self._update_run(run_id, RunStatus.ERROR, team_result=error_result, error=str(error))
def _format_message(self, message: Any) -> Optional[dict]:
|
||||
"""Format message for WebSocket transmission
|
||||
|
@ -236,10 +286,7 @@ class WebSocketManager:
|
|||
"""
|
||||
        try:
            if isinstance(message, (AgentMessage, ChatMessage)):
                return {"type": "message", "data": message.model_dump()}
elif isinstance(message, TeamResult):
|
||||
return {
|
||||
"type": "result",
|
||||
|
@ -260,16 +307,10 @@ class WebSocketManager:
|
|||
        Returns:
            Optional[Run]: Run object if found, None otherwise
        """
        response = self.db_manager.get(Run, filters={"id": run_id}, return_json=False)
        return response.data[0] if response.status and response.data else None

    async def _update_run_status(self, run_id: UUID, status: RunStatus, error: Optional[str] = None) -> None:
"""Update run status in database
|
||||
|
||||
Args:
|
||||
|
@ -285,14 +326,27 @@ class WebSocketManager:
|
|||
|
||||
    async def cleanup(self) -> None:
        """Clean up all active connections and resources when server is shutting down"""
        logger.info(f"Cleaning up {len(self.active_connections)} active connections")

        try:
            # First cancel all running tasks
            for run_id in self.active_runs.copy():
                if run_id in self._cancellation_tokens:
                    self._cancellation_tokens[run_id].cancel()
                run = await self._get_run(run_id)
                if run and run.status == RunStatus.ACTIVE:
                    interrupted_result = TeamResult(
                        task_result=TaskResult(
                            messages=[TextMessage(source="system", content="Run interrupted by server shutdown")],
                            stop_reason="server_shutdown",
                        ),
                        usage="",
                        duration=0,
                    ).model_dump()

                    run.status = RunStatus.STOPPED
                    run.team_result = interrupted_result
                    self.db_manager.upsert(run)

            # Then disconnect all websockets with timeout
            # 10 second timeout for entire cleanup
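cleanup is intended to run exactly once, when the API server shuts down. A minimal sketch of wiring it into a FastAPI lifespan hook; the placeholder manager below stands in for the real WebSocketManager that the deps module provides, since the app wiring is not part of this diff.

```python
from contextlib import asynccontextmanager

from fastapi import FastAPI


class _PlaceholderManager:  # stands in for the shared WebSocketManager instance
    async def cleanup(self) -> None:
        print("flushing active runs and closing sockets")


ws_manager = _PlaceholderManager()


@asynccontextmanager
async def lifespan(app: FastAPI):
    yield  # application runs
    await ws_manager.cleanup()  # mark interrupted runs STOPPED and close connections


app = FastAPI(lifespan=lifespan)
```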
|
|
|
@ -1,181 +1,51 @@
|
|||
# api/routes/agents.py
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Dict
|
||||
from ..deps import get_db
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from ...database import DatabaseManager # Add this import
|
||||
from ...datamodel import Agent, Model, Tool
|
||||
from ..deps import get_db
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_agents(
|
||||
user_id: str,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
async def list_agents(user_id: str, db: DatabaseManager = Depends(get_db)) -> Dict:
|
||||
"""List all agents for a user"""
|
||||
response = db.get(Agent, filters={"user_id": user_id})
|
||||
return {
|
||||
"status": True,
|
||||
"data": response.data
|
||||
}
|
||||
return {"status": True, "data": response.data}
|
||||
|
||||
|
||||
@router.get("/{agent_id}")
|
||||
async def get_agent(
|
||||
agent_id: int,
|
||||
user_id: str,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
async def get_agent(agent_id: int, user_id: str, db: DatabaseManager = Depends(get_db)) -> Dict:
|
||||
"""Get a specific agent"""
|
||||
response = db.get(
|
||||
Agent,
|
||||
filters={"id": agent_id, "user_id": user_id}
|
||||
)
|
||||
response = db.get(Agent, filters={"id": agent_id, "user_id": user_id})
|
||||
if not response.status or not response.data:
|
||||
raise HTTPException(status_code=404, detail="Agent not found")
|
||||
return {
|
||||
"status": True,
|
||||
"data": response.data[0]
|
||||
}
|
||||
return {"status": True, "data": response.data[0]}
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def create_agent(
|
||||
agent: Agent,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
async def create_agent(agent: Agent, db: DatabaseManager = Depends(get_db)) -> Dict:
|
||||
"""Create a new agent"""
|
||||
response = db.upsert(agent)
|
||||
if not response.status:
|
||||
raise HTTPException(status_code=400, detail=response.message)
|
||||
return {
|
||||
"status": True,
|
||||
"data": response.data
|
||||
}
|
||||
return {"status": True, "data": response.data}
|
||||
|
||||
|
||||
@router.delete("/{agent_id}")
|
||||
async def delete_agent(
|
||||
agent_id: int,
|
||||
user_id: str,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
async def delete_agent(agent_id: int, user_id: str, db: DatabaseManager = Depends(get_db)) -> Dict:
|
||||
"""Delete an agent"""
|
||||
response = db.delete(
|
||||
filters={"id": agent_id, "user_id": user_id},
|
||||
model_class=Agent
|
||||
)
|
||||
return {
|
||||
"status": True,
|
||||
"message": "Agent deleted successfully"
|
||||
}
|
||||
db.delete(filters={"id": agent_id, "user_id": user_id}, model_class=Agent)
|
||||
return {"status": True, "message": "Agent deleted successfully"}
|
||||
|
||||
|
||||
# Agent-Model link endpoints
|
||||
|
||||
|
||||
@router.post("/{agent_id}/models/{model_id}")
|
||||
async def link_agent_model(
|
||||
agent_id: int,
|
||||
model_id: int,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
async def link_agent_model(agent_id: int, model_id: int, db: DatabaseManager = Depends(get_db)) -> Dict:
|
||||
"""Link a model to an agent"""
|
||||
response = db.link(
|
||||
link_type="agent_model",
|
||||
primary_id=agent_id,
|
||||
secondary_id=model_id
|
||||
)
|
||||
return {
|
||||
"status": True,
|
||||
"message": "Model linked to agent successfully"
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/{agent_id}/models/{model_id}")
|
||||
async def unlink_agent_model(
|
||||
agent_id: int,
|
||||
model_id: int,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
"""Unlink a model from an agent"""
|
||||
response = db.unlink(
|
||||
link_type="agent_model",
|
||||
primary_id=agent_id,
|
||||
secondary_id=model_id
|
||||
)
|
||||
return {
|
||||
"status": True,
|
||||
"message": "Model unlinked from agent successfully"
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{agent_id}/models")
|
||||
async def get_agent_models(
|
||||
agent_id: int,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
"""Get all models linked to an agent"""
|
||||
response = db.get_linked_entities(
|
||||
link_type="agent_model",
|
||||
primary_id=agent_id,
|
||||
return_json=True
|
||||
)
|
||||
return {
|
||||
"status": True,
|
||||
"data": response.data
|
||||
}
|
||||
|
||||
# Agent-Tool link endpoints
|
||||
|
||||
|
||||
@router.post("/{agent_id}/tools/{tool_id}")
|
||||
async def link_agent_tool(
|
||||
agent_id: int,
|
||||
tool_id: int,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
"""Link a tool to an agent"""
|
||||
response = db.link(
|
||||
link_type="agent_tool",
|
||||
primary_id=agent_id,
|
||||
secondary_id=tool_id
|
||||
)
|
||||
return {
|
||||
"status": True,
|
||||
"message": "Tool linked to agent successfully"
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/{agent_id}/tools/{tool_id}")
|
||||
async def unlink_agent_tool(
|
||||
agent_id: int,
|
||||
tool_id: int,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
"""Unlink a tool from an agent"""
|
||||
response = db.unlink(
|
||||
link_type="agent_tool",
|
||||
primary_id=agent_id,
|
||||
secondary_id=tool_id
|
||||
)
|
||||
return {
|
||||
"status": True,
|
||||
"message": "Tool unlinked from agent successfully"
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{agent_id}/tools")
|
||||
async def get_agent_tools(
|
||||
agent_id: int,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
"""Get all tools linked to an agent"""
|
||||
response = db.get_linked_entities(
|
||||
link_type="agent_tool",
|
||||
primary_id=agent_id,
|
||||
return_json=True
|
||||
)
|
||||
return {
|
||||
"status": True,
|
||||
"data": response.data
|
||||
}
|
||||
db.link(link_type="agent_model", primary_id=agent_id, secondary_id=model_id)
|
||||
return {"status": True, "message": "Model linked to agent successfully"}
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
# api/routes/gallery.py
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from ...database import DatabaseManager
|
||||
from ...datamodel import Gallery, GalleryConfig, Response, Run, Session
|
||||
from ..deps import get_db
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def create_gallery_entry(
|
||||
gallery_data: GalleryConfig, user_id: str, db: DatabaseManager = Depends(get_db)
|
||||
) -> Response:
|
||||
# First validate that user owns all runs
|
||||
for run in gallery_data.runs:
|
||||
run_result = db.get(Run, filters={"id": run.id})
|
||||
if not run_result.status or not run_result.data:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run.id} not found")
|
||||
|
||||
# Get associated session to check ownership
|
||||
session_result = db.get(Session, filters={"id": run_result.data[0].session_id})
|
||||
if not session_result.status or not session_result.data or session_result.data[0].user_id != user_id:
|
||||
raise HTTPException(status_code=403, detail=f"Not authorized to add run {run.id} to gallery")
|
||||
|
||||
# Create gallery entry
|
||||
gallery = Gallery(user_id=user_id, config=gallery_data)
|
||||
result = db.upsert(gallery)
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/{gallery_id}")
|
||||
async def get_gallery_entry(gallery_id: int, user_id: str, db: DatabaseManager = Depends(get_db)) -> Response:
|
||||
result = db.get(Gallery, filters={"id": gallery_id})
|
||||
if not result.status or not result.data:
|
||||
raise HTTPException(status_code=404, detail="Gallery entry not found")
|
||||
|
||||
gallery = result.data[0]
|
||||
if gallery.config["visibility"] != "public" and gallery.user_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to view this gallery entry")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_gallery_entries(user_id: str, db: DatabaseManager = Depends(get_db)) -> Response:
|
||||
result = db.get(Gallery, filters={"user_id": user_id})
|
||||
return result
|
||||
|
||||
|
||||
@router.delete("/{gallery_id}")
|
||||
async def delete_gallery_entry(gallery_id: int, user_id: str, db: DatabaseManager = Depends(get_db)) -> Response:
|
||||
# Check ownership first
|
||||
result = db.get(Gallery, filters={"id": gallery_id})
|
||||
if not result.status or not result.data:
|
||||
raise HTTPException(status_code=404, detail="Gallery entry not found")
|
||||
|
||||
if result.data[0].user_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to delete this gallery entry")
|
||||
|
||||
# Delete if authorized
|
||||
return db.delete(Gallery, filters={"id": gallery_id})
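Exercising the gallery routes looks similar; the payload below is illustrative only, since the real GalleryConfig fields live in the datamodel and are not shown in this diff, and the URL prefix is an assumption. The server rejects the request unless every referenced run belongs to the calling user.

```python
import httpx

BASE = "http://127.0.0.1:8081/api/gallery"  # assumed mount point
USER = "demo@example.com"

with httpx.Client() as client:
    payload = {"visibility": "public", "runs": []}  # hypothetical GalleryConfig fields
    created = client.post(f"{BASE}/", params={"user_id": USER}, json=payload)
    entries = client.get(f"{BASE}/", params={"user_id": USER})
    print(created.status_code, entries.json())
```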
|
|
@ -1,95 +1,42 @@
|
|||
# api/routes/models.py
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Dict
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from openai import OpenAIError
|
||||
from ..deps import get_db
|
||||
|
||||
from ...datamodel import Model
|
||||
from ...utils import test_model
|
||||
from ..deps import get_db
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_models(
|
||||
user_id: str,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
async def list_models(user_id: str, db=Depends(get_db)) -> Dict:
|
||||
"""List all models for a user"""
|
||||
response = db.get(Model, filters={"user_id": user_id})
|
||||
return {
|
||||
"status": True,
|
||||
"data": response.data
|
||||
}
|
||||
return {"status": True, "data": response.data}
|
||||
|
||||
|
||||
@router.get("/{model_id}")
|
||||
async def get_model(
|
||||
model_id: int,
|
||||
user_id: str,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
async def get_model(model_id: int, user_id: str, db=Depends(get_db)) -> Dict:
|
||||
"""Get a specific model"""
|
||||
response = db.get(
|
||||
Model,
|
||||
filters={"id": model_id, "user_id": user_id}
|
||||
)
|
||||
response = db.get(Model, filters={"id": model_id, "user_id": user_id})
|
||||
if not response.status or not response.data:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
return {
|
||||
"status": True,
|
||||
"data": response.data[0]
|
||||
}
|
||||
return {"status": True, "data": response.data[0]}
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def create_model(
|
||||
model: Model,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
async def create_model(model: Model, db=Depends(get_db)) -> Dict:
|
||||
"""Create a new model"""
|
||||
response = db.upsert(model)
|
||||
if not response.status:
|
||||
raise HTTPException(status_code=400, detail=response.message)
|
||||
return {
|
||||
"status": True,
|
||||
"data": response.data
|
||||
}
|
||||
return {"status": True, "data": response.data}
|
||||
|
||||
|
||||
@router.delete("/{model_id}")
|
||||
async def delete_model(
|
||||
model_id: int,
|
||||
user_id: str,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
async def delete_model(model_id: int, user_id: str, db=Depends(get_db)) -> Dict:
|
||||
"""Delete a model"""
|
||||
response = db.delete(
|
||||
filters={"id": model_id, "user_id": user_id},
|
||||
model_class=Model
|
||||
)
|
||||
return {
|
||||
"status": True,
|
||||
"message": "Model deleted successfully"
|
||||
}
|
||||
|
||||
|
||||
@router.post("/test")
|
||||
async def test_model_endpoint(model: Model) -> Dict:
|
||||
"""Test a model configuration"""
|
||||
try:
|
||||
response = test_model(model)
|
||||
return {
|
||||
"status": True,
|
||||
"message": "Model tested successfully",
|
||||
"data": response
|
||||
}
|
||||
except OpenAIError as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"OpenAI API error: {str(e)}"
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error testing model: {str(e)}"
|
||||
)
|
||||
db.delete(filters={"id": model_id, "user_id": user_id}, model_class=Model)
|
||||
return {"status": True, "message": "Model deleted successfully"}
|
||||
|
|
|
@ -1,14 +1,12 @@
|
|||
# /api/runs routes
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
from uuid import UUID
|
||||
from typing import Dict
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from ..deps import get_db, get_websocket_manager, get_team_manager
|
||||
from ...datamodel import Run, Session, Message, Team, RunStatus, MessageConfig
|
||||
|
||||
from ...teammanager import TeamManager
|
||||
from autogen_core.base import CancellationToken
|
||||
from ...datamodel import Message, MessageConfig, Run, RunStatus, Session, Team
|
||||
from ..deps import get_db, get_team_manager, get_websocket_manager
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
@ -23,54 +21,45 @@ async def create_run(
|
|||
request: CreateRunRequest,
|
||||
db=Depends(get_db),
|
||||
) -> Dict:
|
||||
"""Create a new run"""
|
||||
"""Create a new run with initial state"""
|
||||
session_response = db.get(
|
||||
Session,
|
||||
filters={"id": request.session_id, "user_id": request.user_id},
|
||||
return_json=False
|
||||
Session, filters={"id": request.session_id, "user_id": request.user_id}, return_json=False
|
||||
)
|
||||
if not session_response.status or not session_response.data:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
try:
|
||||
|
||||
run = db.upsert(Run(session_id=request.session_id), return_json=False)
|
||||
return {
|
||||
"status": run.status,
|
||||
"data": {"run_id": str(run.data.id)}
|
||||
}
|
||||
|
||||
# }
|
||||
# Create run with default state
|
||||
run = db.upsert(
|
||||
Run(
|
||||
session_id=request.session_id,
|
||||
status=RunStatus.CREATED,
|
||||
task=None, # Will be set when run starts
|
||||
team_result=None,
|
||||
),
|
||||
return_json=False,
|
||||
)
|
||||
return {"status": run.status, "data": {"run_id": str(run.data.id)}}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.post("/{run_id}/start")
|
||||
async def start_run(
|
||||
run_id: UUID,
|
||||
message: Message = Body(...),
|
||||
ws_manager=Depends(get_websocket_manager),
|
||||
team_manager=Depends(get_team_manager),
|
||||
db=Depends(get_db),
|
||||
) -> Dict:
|
||||
"""Start streaming task execution"""
|
||||
# We might want to add these endpoints:
|
||||
|
||||
if isinstance(message.config, dict):
|
||||
message.config = MessageConfig(**message.config)
|
||||
|
||||
session = db.get(Session, filters={
|
||||
"id": message.session_id}, return_json=False)
|
||||
@router.get("/{run_id}")
|
||||
async def get_run(run_id: UUID, db=Depends(get_db)) -> Dict:
|
||||
"""Get run details including task and result"""
|
||||
run = db.get(Run, filters={"id": run_id}, return_json=False)
|
||||
if not run.status or not run.data:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
|
||||
team = db.get(
|
||||
Team, filters={"id": session.data[0].team_id}, return_json=False)
|
||||
return {"status": True, "data": run.data[0]}
|
||||
|
||||
try:
|
||||
await ws_manager.start_stream(run_id, team_manager, message.config.content, team.data[0].config)
|
||||
return {
|
||||
"status": True,
|
||||
"message": "Stream started successfully",
|
||||
"data": {"run_id": str(run_id)}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@router.get("/{run_id}/messages")
|
||||
async def get_run_messages(run_id: UUID, db=Depends(get_db)) -> Dict:
|
||||
"""Get all messages for a run"""
|
||||
messages = db.get(Message, filters={"run_id": run_id}, order="created_at asc", return_json=False)
|
||||
|
||||
return {"status": True, "data": messages.data}
|
||||
|
|
|
@ -1,72 +1,45 @@
|
|||
# api/routes/sessions.py
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Dict
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from ...datamodel import Message, Run, Session
|
||||
from ..deps import get_db
|
||||
from ...datamodel import Session, Message
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_sessions(
|
||||
user_id: str,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
async def list_sessions(user_id: str, db=Depends(get_db)) -> Dict:
|
||||
"""List all sessions for a user"""
|
||||
response = db.get(Session, filters={"user_id": user_id})
|
||||
return {
|
||||
"status": True,
|
||||
"data": response.data
|
||||
}
|
||||
return {"status": True, "data": response.data}
|
||||
|
||||
|
||||
@router.get("/{session_id}")
|
||||
async def get_session(
|
||||
session_id: int,
|
||||
user_id: str,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
async def get_session(session_id: int, user_id: str, db=Depends(get_db)) -> Dict:
|
||||
"""Get a specific session"""
|
||||
response = db.get(
|
||||
Session,
|
||||
filters={"id": session_id, "user_id": user_id}
|
||||
)
|
||||
response = db.get(Session, filters={"id": session_id, "user_id": user_id})
|
||||
if not response.status or not response.data:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return {
|
||||
"status": True,
|
||||
"data": response.data[0]
|
||||
}
|
||||
return {"status": True, "data": response.data[0]}
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def create_session(
|
||||
session: Session,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
async def create_session(session: Session, db=Depends(get_db)) -> Dict:
|
||||
"""Create a new session"""
|
||||
response = db.upsert(session)
|
||||
if not response.status:
|
||||
raise HTTPException(status_code=400, detail=response.message)
|
||||
return {
|
||||
"status": True,
|
||||
"data": response.data
|
||||
}
|
||||
return {"status": True, "data": response.data}
|
||||
|
||||
|
||||
@router.put("/{session_id}")
|
||||
async def update_session(
|
||||
session_id: int,
|
||||
user_id: str,
|
||||
session: Session,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
async def update_session(session_id: int, user_id: str, session: Session, db=Depends(get_db)) -> Dict:
|
||||
"""Update an existing session"""
|
||||
# First verify the session belongs to user
|
||||
existing = db.get(
|
||||
Session,
|
||||
filters={"id": session_id, "user_id": user_id}
|
||||
)
|
||||
existing = db.get(Session, filters={"id": session_id, "user_id": user_id})
|
||||
if not existing.status or not existing.data:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
|
@ -75,40 +48,74 @@ async def update_session(
|
|||
if not response.status:
|
||||
raise HTTPException(status_code=400, detail=response.message)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"data": response.data,
|
||||
"message": "Session updated successfully"
|
||||
}
|
||||
return {"status": True, "data": response.data, "message": "Session updated successfully"}
|
||||
|
||||
|
||||
@router.delete("/{session_id}")
|
||||
async def delete_session(
|
||||
session_id: int,
|
||||
user_id: str,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
async def delete_session(session_id: int, user_id: str, db=Depends(get_db)) -> Dict:
|
||||
"""Delete a session"""
|
||||
response = db.delete(
|
||||
filters={"id": session_id, "user_id": user_id},
|
||||
model_class=Session
|
||||
)
|
||||
return {
|
||||
"status": True,
|
||||
"message": "Session deleted successfully"
|
||||
}
|
||||
db.delete(filters={"id": session_id, "user_id": user_id}, model_class=Session)
|
||||
return {"status": True, "message": "Session deleted successfully"}
|
||||
|
||||
|
||||
@router.get("/{session_id}/messages")
|
||||
async def list_messages(
|
||||
session_id: int,
|
||||
user_id: str,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
"""List all messages for a session"""
|
||||
filters = {"session_id": session_id, "user_id": user_id}
|
||||
response = db.get(Message, filters=filters, order="asc")
|
||||
return {
|
||||
"status": True,
|
||||
"data": response.data
|
||||
}
|
||||
@router.get("/{session_id}/runs")
|
||||
async def list_session_runs(session_id: int, user_id: str, db=Depends(get_db)) -> Dict:
|
||||
"""Get complete session history organized by runs"""
|
||||
|
||||
try:
|
||||
# 1. Verify session exists and belongs to user
|
||||
session = db.get(Session, filters={"id": session_id, "user_id": user_id}, return_json=False)
|
||||
if not session.status:
|
||||
raise HTTPException(status_code=500, detail="Database error while fetching session")
|
||||
if not session.data:
|
||||
raise HTTPException(status_code=404, detail="Session not found or access denied")
|
||||
|
||||
# 2. Get ordered runs for session
|
||||
runs = db.get(Run, filters={"session_id": session_id}, order="asc", return_json=False)
|
||||
if not runs.status:
|
||||
raise HTTPException(status_code=500, detail="Database error while fetching runs")
|
||||
|
||||
# 3. Build response with messages per run
|
||||
run_data = []
|
||||
if runs.data: # It's ok to have no runs
|
||||
for run in runs.data:
|
||||
try:
|
||||
# Get messages for this specific run
|
||||
messages = db.get(Message, filters={"run_id": run.id}, order="asc", return_json=False)
|
||||
if not messages.status:
|
||||
logger.error(f"Failed to fetch messages for run {run.id}")
|
||||
# Continue processing other runs even if one fails
|
||||
messages.data = []
|
||||
|
||||
run_data.append(
|
||||
{
|
||||
"id": str(run.id),
|
||||
"created_at": run.created_at,
|
||||
"status": run.status,
|
||||
"task": run.task,
|
||||
"team_result": run.team_result,
|
||||
"messages": messages.data or [],
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing run {run.id}: {str(e)}")
|
||||
# Include run with error state instead of failing entirely
|
||||
run_data.append(
|
||||
{
|
||||
"id": str(run.id),
|
||||
"created_at": run.created_at,
|
||||
"status": "ERROR",
|
||||
"task": run.task,
|
||||
"team_result": None,
|
||||
"messages": [],
|
||||
"error": f"Failed to process run: {str(e)}",
|
||||
}
|
||||
)
|
||||
|
||||
return {"status": True, "data": {"runs": run_data}}
|
||||
|
||||
except HTTPException:
|
||||
raise # Re-raise HTTP exceptions
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in list_messages: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error while fetching session data") from e
|
||||
|
|
|
@ -1,146 +1,41 @@
|
|||
# api/routes/teams.py
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Dict
|
||||
from ..deps import get_db
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from ...datamodel import Team
|
||||
from ..deps import get_db
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_teams(
|
||||
user_id: str,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
async def list_teams(user_id: str, db=Depends(get_db)) -> Dict:
|
||||
"""List all teams for a user"""
|
||||
response = db.get(Team, filters={"user_id": user_id})
|
||||
return {
|
||||
"status": True,
|
||||
"data": response.data
|
||||
}
|
||||
return {"status": True, "data": response.data}
|
||||
|
||||
|
||||
@router.get("/{team_id}")
|
||||
async def get_team(
|
||||
team_id: int,
|
||||
user_id: str,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
async def get_team(team_id: int, user_id: str, db=Depends(get_db)) -> Dict:
|
||||
"""Get a specific team"""
|
||||
response = db.get(
|
||||
Team,
|
||||
filters={"id": team_id, "user_id": user_id}
|
||||
)
|
||||
response = db.get(Team, filters={"id": team_id, "user_id": user_id})
|
||||
if not response.status or not response.data:
|
||||
raise HTTPException(status_code=404, detail="Team not found")
|
||||
return {
|
||||
"status": True,
|
||||
"data": response.data[0]
|
||||
}
|
||||
return {"status": True, "data": response.data[0]}
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def create_team(
|
||||
team: Team,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
async def create_team(team: Team, db=Depends(get_db)) -> Dict:
|
||||
"""Create a new team"""
|
||||
response = db.upsert(team)
|
||||
if not response.status:
|
||||
raise HTTPException(status_code=400, detail=response.message)
|
||||
return {
|
||||
"status": True,
|
||||
"data": response.data
|
||||
}
|
||||
return {"status": True, "data": response.data}
|
||||
|
||||
|
||||
@router.delete("/{team_id}")
|
||||
async def delete_team(
|
||||
team_id: int,
|
||||
user_id: str,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
async def delete_team(team_id: int, user_id: str, db=Depends(get_db)) -> Dict:
|
||||
"""Delete a team"""
|
||||
response = db.delete(
|
||||
filters={"id": team_id, "user_id": user_id},
|
||||
model_class=Team
|
||||
)
|
||||
return {
|
||||
"status": True,
|
||||
"message": "Team deleted successfully"
|
||||
}
|
||||
|
||||
# Team-Agent link endpoints
|
||||
|
||||
|
||||
@router.post("/{team_id}/agents/{agent_id}")
|
||||
async def link_team_agent(
|
||||
team_id: int,
|
||||
agent_id: int,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
"""Link an agent to a team"""
|
||||
response = db.link(
|
||||
link_type="team_agent",
|
||||
primary_id=team_id,
|
||||
secondary_id=agent_id
|
||||
)
|
||||
return {
|
||||
"status": True,
|
||||
"message": "Agent linked to team successfully"
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{team_id}/agents/{agent_id}/{sequence_id}")
|
||||
async def link_team_agent_sequence(
|
||||
team_id: int,
|
||||
agent_id: int,
|
||||
sequence_id: int,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
"""Link an agent to a team with sequence"""
|
||||
response = db.link(
|
||||
link_type="team_agent",
|
||||
primary_id=team_id,
|
||||
secondary_id=agent_id,
|
||||
sequence_id=sequence_id
|
||||
)
|
||||
return {
|
||||
"status": True,
|
||||
"message": "Agent linked to team with sequence successfully"
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/{team_id}/agents/{agent_id}")
|
||||
async def unlink_team_agent(
|
||||
team_id: int,
|
||||
agent_id: int,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
"""Unlink an agent from a team"""
|
||||
response = db.unlink(
|
||||
link_type="team_agent",
|
||||
primary_id=team_id,
|
||||
secondary_id=agent_id
|
||||
)
|
||||
return {
|
||||
"status": True,
|
||||
"message": "Agent unlinked from team successfully"
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{team_id}/agents")
|
||||
async def get_team_agents(
|
||||
team_id: int,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
"""Get all agents linked to a team"""
|
||||
response = db.get_linked_entities(
|
||||
link_type="team_agent",
|
||||
primary_id=team_id,
|
||||
return_json=True
|
||||
)
|
||||
return {
|
||||
"status": True,
|
||||
"data": response.data
|
||||
}
|
||||
db.delete(filters={"id": team_id, "user_id": user_id}, model_class=Team)
|
||||
return {"status": True, "message": "Team deleted successfully"}
|
||||
|
|
|
@ -1,103 +1,41 @@
|
|||
# api/routes/tools.py
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Dict
|
||||
from ..deps import get_db
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from ...datamodel import Tool
|
||||
from ..deps import get_db
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_tools(
|
||||
user_id: str,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
async def list_tools(user_id: str, db=Depends(get_db)) -> Dict:
|
||||
"""List all tools for a user"""
|
||||
response = db.get(Tool, filters={"user_id": user_id})
|
||||
return {
|
||||
"status": True,
|
||||
"data": response.data
|
||||
}
|
||||
return {"status": True, "data": response.data}
|
||||
|
||||
|
||||
@router.get("/{tool_id}")
|
||||
async def get_tool(
|
||||
tool_id: int,
|
||||
user_id: str,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
async def get_tool(tool_id: int, user_id: str, db=Depends(get_db)) -> Dict:
|
||||
"""Get a specific tool"""
|
||||
response = db.get(
|
||||
Tool,
|
||||
filters={"id": tool_id, "user_id": user_id}
|
||||
)
|
||||
response = db.get(Tool, filters={"id": tool_id, "user_id": user_id})
|
||||
if not response.status or not response.data:
|
||||
raise HTTPException(status_code=404, detail="Tool not found")
|
||||
return {
|
||||
"status": True,
|
||||
"data": response.data[0]
|
||||
}
|
||||
return {"status": True, "data": response.data[0]}
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def create_tool(
|
||||
tool: Tool,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
async def create_tool(tool: Tool, db=Depends(get_db)) -> Dict:
|
||||
"""Create a new tool"""
|
||||
response = db.upsert(tool)
|
||||
if not response.status:
|
||||
raise HTTPException(status_code=400, detail=response.message)
|
||||
return {
|
||||
"status": True,
|
||||
"data": response.data
|
||||
}
|
||||
return {"status": True, "data": response.data}
|
||||
|
||||
|
||||
@router.delete("/{tool_id}")
|
||||
async def delete_tool(
|
||||
tool_id: int,
|
||||
user_id: str,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
async def delete_tool(tool_id: int, user_id: str, db=Depends(get_db)) -> Dict:
|
||||
"""Delete a tool"""
|
||||
response = db.delete(
|
||||
filters={"id": tool_id, "user_id": user_id},
|
||||
model_class=Tool
|
||||
)
|
||||
return {
|
||||
"status": True,
|
||||
"message": "Tool deleted successfully"
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{tool_id}/test")
|
||||
async def test_tool(
|
||||
tool_id: int,
|
||||
user_id: str,
|
||||
db=Depends(get_db)
|
||||
) -> Dict:
|
||||
"""Test a tool configuration"""
|
||||
# Get tool
|
||||
tool_response = db.get(
|
||||
Tool,
|
||||
filters={"id": tool_id, "user_id": user_id}
|
||||
)
|
||||
if not tool_response.status or not tool_response.data:
|
||||
raise HTTPException(status_code=404, detail="Tool not found")
|
||||
|
||||
tool = tool_response.data[0]
|
||||
|
||||
try:
|
||||
# Implement tool testing logic here
|
||||
# This would depend on the tool type and configuration
|
||||
return {
|
||||
"status": True,
|
||||
"message": "Tool tested successfully",
|
||||
"data": {"tool_id": tool_id}
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error testing tool: {str(e)}"
|
||||
)
|
||||
db.delete(filters={"id": tool_id, "user_id": user_id}, model_class=Tool)
|
||||
return {"status": True, "message": "Tool deleted successfully"}
|
||||
|
|
|
@ -1,17 +1,17 @@
|
|||
# api/ws.py
import asyncio
import json
from datetime import datetime
from uuid import UUID

from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
from loguru import logger

from ...datamodel import Run, RunStatus
from ..deps import get_db, get_websocket_manager
from ..managers import WebSocketManager

router = APIRouter()

@router.websocket("/runs/{run_id}")
|
||||
|
@ -20,12 +20,12 @@ async def run_websocket(
|
|||
    run_id: UUID,
    ws_manager: WebSocketManager = Depends(get_websocket_manager),
    db=Depends(get_db),
):
    """WebSocket endpoint for run communication"""
    # Verify run exists and is in valid state
    run_response = db.get(Run, filters={"id": run_id}, return_json=False)
    if not run_response.status or not run_response.data:
        logger.warning(f"Run not found: {run_id}")
        await websocket.close(code=4004, reason="Run not found")
        return
|
||||
|
@ -48,18 +48,32 @@ async def run_websocket(
|
|||
                raw_message = await websocket.receive_text()
                message = json.loads(raw_message)

                if message.get("type") == "start":
                    # Handle start message
                    logger.info(f"Received start request for run {run_id}")
                    task = message.get("task")
                    team_config = message.get("team_config")
                    if task and team_config:
                        # await ws_manager.start_stream(run_id, task, team_config)
                        asyncio.create_task(ws_manager.start_stream(run_id, task, team_config))
                    else:
                        logger.warning(f"Invalid start message format for run {run_id}")
                        await websocket.send_json(
                            {
                                "type": "error",
                                "error": "Invalid start message format",
                                "timestamp": datetime.utcnow().isoformat(),
                            }
                        )

                elif message.get("type") == "stop":
                    logger.info(f"Received stop request for run {run_id}")
                    reason = message.get("reason") or "User requested stop/cancellation"
                    await ws_manager.stop_run(run_id, reason=reason)
                    break

                elif message.get("type") == "ping":
                    await websocket.send_json({"type": "pong", "timestamp": datetime.utcnow().isoformat()})

                elif message.get("type") == "input_response":
                    # Handle input response from client
|
@ -67,16 +81,13 @@ async def run_websocket(
|
|||
                    if response is not None:
                        await ws_manager.handle_input_response(run_id, response)
                    else:
                        logger.warning(f"Invalid input response format for run {run_id}")

            except json.JSONDecodeError:
                logger.warning(f"Invalid JSON received: {raw_message}")
                await websocket.send_json(
                    {"type": "error", "error": "Invalid message format", "timestamp": datetime.utcnow().isoformat()}
                )
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket disconnected for run {run_id}")
|
||||
|
|
|
@ -17,8 +17,6 @@
|
|||
"typecheck": "tsc --noEmit"
|
||||
},
|
||||
"dependencies": {
|
||||
"@ant-design/charts": "^2.2.3",
|
||||
"@ant-design/plots": "^2.2.2",
|
||||
"@dagrejs/dagre": "^1.1.4",
|
||||
"@headlessui/react": "^2.2.0",
|
||||
"@heroicons/react": "^2.0.18",
|
||||
|
@ -38,7 +36,7 @@
|
|||
"gatsby-source-filesystem": "^5.14.0",
|
||||
"gatsby-transformer-sharp": "^5.14.0",
|
||||
"install": "^0.13.0",
|
||||
"lucide-react": "^0.456.0",
|
||||
"lucide-react": "^0.460.0",
|
||||
"postcss": "^8.4.49",
|
||||
"react": "^18.2.0",
|
||||
"react-dom": "^18.2.0",
|
||||
|
|
|
@ -105,22 +105,15 @@ export type ThreadStatus =
|
|||
| "timeout";
|
||||
|
||||
export interface WebSocketMessage {
  type: "message" | "result" | "completion" | "input_request" | "error";
  data?: AgentMessageConfig | TaskResult;
  status?: RunStatus;
  error?: string;
  timestamp?: string;
}
|
||||
|
||||
export interface TaskResult {
|
||||
messages: AgentMessageConfig[];
|
||||
usage: string;
|
||||
duration: number;
|
||||
stop_reason?: string;
|
||||
}
|
||||
|
||||
|
@ -187,3 +180,36 @@ export interface TeamConfig extends BaseConfig {
|
|||
export interface Team extends DBModel {
|
||||
config: TeamConfig;
|
||||
}
|
||||
|
||||
export interface TeamResult {
|
||||
task_result: TaskResult;
|
||||
usage: string;
|
||||
duration: number;
|
||||
}
|
||||
|
||||
export interface Run {
|
||||
id: string;
|
||||
created_at: string;
|
||||
status: RunStatus;
|
||||
task: AgentMessageConfig;
|
||||
team_result: TeamResult | null;
|
||||
messages: Message[]; // Change to Message[]
|
||||
error_message?: string;
|
||||
}
|
||||
|
||||
// Separate transient state
|
||||
interface TransientRunState {
|
||||
pendingInput?: {
|
||||
prompt: string;
|
||||
isPending: boolean;
|
||||
};
|
||||
}
|
||||
|
||||
export type RunStatus =
|
||||
| "created"
|
||||
| "active" // covers 'streaming'
|
||||
| "awaiting_input"
|
||||
| "timeout"
|
||||
| "complete"
|
||||
| "error"
|
||||
| "stopped";
|
||||
|
|
|
@ -10,7 +10,6 @@ import {
|
|||
Node,
|
||||
Edge,
|
||||
Background,
|
||||
Controls,
|
||||
NodeTypes,
|
||||
useReactFlow,
|
||||
ReactFlowProvider,
|
||||
|
@ -24,16 +23,16 @@ import {
|
|||
AgentMessageConfig,
|
||||
AgentConfig,
|
||||
TeamConfig,
|
||||
Run,
|
||||
} from "../../../../types/datamodel";
|
||||
import { ThreadState } from "../types";
|
||||
import { CustomEdge } from "./edge";
|
||||
import { CustomEdge, CustomEdgeData } from "./edge";
|
||||
import { useConfigStore } from "../../../../../hooks/store";
|
||||
import { AgentFlowToolbar } from "./toolbar";
|
||||
import { EdgeMessageModal } from "./edgemessagemodal";
|
||||
|
||||
interface AgentFlowProps {
|
||||
teamConfig: TeamConfig;
|
||||
messages: AgentMessageConfig[];
|
||||
threadState: ThreadState;
|
||||
run: Run;
|
||||
}
|
||||
|
||||
interface MessageSequence {
|
||||
|
@ -51,40 +50,37 @@ interface BidirectionalPattern {
|
|||
|
||||
const NODE_DIMENSIONS = {
|
||||
default: { width: 170, height: 100 },
|
||||
end: { width: 120, height: 80 },
|
||||
end: { width: 170, height: 80 },
|
||||
task: { width: 170, height: 100 },
|
||||
};
|
||||
|
||||
const getLayoutedElements = (
|
||||
nodes: Node[],
|
||||
edges: Edge[],
|
||||
edges: CustomEdge[],
|
||||
direction: "TB" | "LR"
|
||||
) => {
|
||||
const g = new Dagre.graphlib.Graph().setDefaultEdgeLabel(() => ({}));
|
||||
|
||||
// Updated graph settings
|
||||
g.setGraph({
|
||||
rankdir: direction,
|
||||
nodesep: 80,
|
||||
ranksep: 120,
|
||||
ranker: "tight-tree",
|
||||
nodesep: 110,
|
||||
ranksep: 100,
|
||||
ranker: "network-simplex",
|
||||
marginx: 30,
|
||||
marginy: 30,
|
||||
});
|
||||
|
||||
// Add nodes (unchanged)
|
||||
nodes.forEach((node) => {
|
||||
const dimensions =
|
||||
node.data.type === "end" ? NODE_DIMENSIONS.end : NODE_DIMENSIONS.default;
|
||||
g.setNode(node.id, { ...node, ...dimensions });
|
||||
});
|
||||
|
||||
// Create a map to track bidirectional edges
|
||||
const bidirectionalPairs = new Map<
|
||||
string,
|
||||
{ source: string; target: string }[]
|
||||
>();
|
||||
|
||||
// First pass - identify bidirectional pairs
|
||||
edges.forEach((edge) => {
|
||||
const forwardKey = `${edge.source}->${edge.target}`;
|
||||
const reverseKey = `${edge.target}->${edge.source}`;
|
||||
|
@ -99,33 +95,19 @@ const getLayoutedElements = (
|
|||
});
|
||||
});
|
||||
|
||||
// Second pass - add edges with weights
|
||||
bidirectionalPairs.forEach((pairs, pairKey) => {
|
||||
if (pairs.length === 2) {
|
||||
// Bidirectional edge
|
||||
const [first, second] = pairs;
|
||||
g.setEdge(first.source, first.target, {
|
||||
weight: 2,
|
||||
minlen: 1,
|
||||
});
|
||||
g.setEdge(second.source, second.target, {
|
||||
weight: 1,
|
||||
minlen: 1,
|
||||
});
|
||||
g.setEdge(first.source, first.target, { weight: 2, minlen: 1 });
|
||||
g.setEdge(second.source, second.target, { weight: 1, minlen: 1 });
|
||||
} else {
|
||||
// Regular edge
|
||||
const edge = pairs[0];
|
||||
g.setEdge(edge.source, edge.target, {
|
||||
weight: 1,
|
||||
minlen: 1,
|
||||
});
|
||||
g.setEdge(edge.source, edge.target, { weight: 1, minlen: 1 });
|
||||
}
|
||||
});
|
||||
|
||||
// Run layout
|
||||
Dagre.layout(g);
|
||||
|
||||
// Position nodes
|
||||
const positionedNodes = nodes.map((node) => {
|
||||
const { x, y } = g.node(node.id);
|
||||
const dimensions =
|
||||
|
@ -139,7 +121,6 @@ const getLayoutedElements = (
|
|||
};
|
||||
});
|
||||
|
||||
// Create a map of node positions for edge calculations
|
||||
const nodePositions = new Map(
|
||||
positionedNodes.map((node) => [
|
||||
node.id,
|
||||
|
@ -150,7 +131,6 @@ const getLayoutedElements = (
|
|||
])
|
||||
);
|
||||
|
||||
// Process edges with positions
|
||||
const positionedEdges = edges.map((edge) => {
|
||||
const sourcePos = nodePositions.get(edge.source)!;
|
||||
const targetPos = nodePositions.get(edge.target)!;
|
||||
|
@ -164,10 +144,7 @@ const getLayoutedElements = (
|
|||
};
|
||||
});
|
||||
|
||||
return {
|
||||
nodes: positionedNodes,
|
||||
edges: positionedEdges,
|
||||
};
|
||||
return { nodes: positionedNodes, edges: positionedEdges };
|
||||
};
|
||||
|
||||
const createNode = (
|
||||
|
@ -175,11 +152,11 @@ const createNode = (
|
|||
type: "user" | "agent" | "end",
|
||||
agentConfig?: AgentConfig,
|
||||
isActive: boolean = false,
|
||||
threadState?: ThreadState
|
||||
run?: Run
|
||||
): Node => {
|
||||
const isStreamingOrWaiting =
|
||||
threadState?.status === "streaming" ||
|
||||
threadState?.status === "awaiting_input";
|
||||
const isProcessing =
|
||||
run?.status === "active" || run?.status === "awaiting_input";
|
||||
|
||||
if (type === "user") {
|
||||
return {
|
||||
id,
|
||||
|
@ -193,7 +170,7 @@ const createNode = (
|
|||
isActive,
|
||||
status: "",
|
||||
reason: "",
|
||||
draggable: !isStreamingOrWaiting,
|
||||
draggable: !isProcessing,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
@ -206,8 +183,8 @@ const createNode = (
|
|||
data: {
|
||||
type: "end",
|
||||
label: "End",
|
||||
status: threadState?.status,
|
||||
reason: threadState?.reason || "",
|
||||
status: run?.status,
|
||||
reason: run?.error_message || "",
|
||||
agentType: "",
|
||||
description: "",
|
||||
isActive: false,
|
||||
|
@ -216,6 +193,23 @@ const createNode = (
|
|||
};
|
||||
}
|
||||
|
||||
// if (type === "task") {
|
||||
// return {
|
||||
// id,
|
||||
// type: "agentNode",
|
||||
// position: { x: 0, y: 0 },
|
||||
// data: {
|
||||
// type: "task",
|
||||
// label: "Task",
|
||||
// description: run?.task.content || "",
|
||||
// isActive: false,
|
||||
// status: null,
|
||||
// reason: null,
|
||||
// draggable: false,
|
||||
// },
|
||||
// };
|
||||
// }
|
||||
|
||||
return {
|
||||
id,
|
||||
type: "agentNode",
|
||||
|
@ -228,7 +222,7 @@ const createNode = (
|
|||
isActive,
|
||||
status: "",
|
||||
reason: "",
|
||||
draggable: !isStreamingOrWaiting,
|
||||
draggable: !isProcessing,
|
||||
},
|
||||
};
|
||||
};
|
||||
|
@ -241,23 +235,28 @@ const edgeTypes = {
|
|||
custom: CustomEdge,
|
||||
};
|
||||
|
||||
const AgentFlow: React.FC<AgentFlowProps> = ({
|
||||
teamConfig,
|
||||
messages,
|
||||
threadState,
|
||||
}) => {
|
||||
const AgentFlow: React.FC<AgentFlowProps> = ({ teamConfig, run }) => {
|
||||
const { fitView } = useReactFlow();
|
||||
const [nodes, setNodes] = useState<Node[]>([]);
|
||||
const [edges, setEdges] = useState<Edge[]>([]);
|
||||
const [edges, setEdges] = useState<CustomEdge[]>([]);
|
||||
const [shouldRefit, setShouldRefit] = useState(false);
|
||||
const [isFullscreen, setIsFullscreen] = useState(false);
|
||||
|
||||
// Get settings from store
|
||||
const { agentFlow: settings, setAgentFlowSettings } = useConfigStore();
|
||||
const { agentFlow: settings } = useConfigStore();
|
||||
const [modalOpen, setModalOpen] = useState(false);
|
||||
const [selectedEdge, setSelectedEdge] = useState<CustomEdge | null>(null);
|
||||
|
||||
const handleEdgeClick = useCallback((edge: CustomEdge) => {
|
||||
if (!edge.data?.messages) return; // Early return if no data/messages
|
||||
|
||||
setSelectedEdge(edge);
|
||||
setModalOpen(true);
|
||||
}, []);
|
||||
|
||||
const onNodesChange = useCallback((changes: NodeChange[]) => {
|
||||
setNodes((nds) => applyNodeChanges(changes, nds));
|
||||
}, []);
|
||||
|
||||
const flowWrapper = useRef<HTMLDivElement>(null);
|
||||
|
||||
useEffect(() => {
|
||||
|
@ -265,36 +264,38 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
|
|||
const timeoutId = setTimeout(() => {
|
||||
fitView({ padding: 0.2, duration: 200 });
|
||||
setShouldRefit(false);
|
||||
}, 100); // Increased delay slightly
|
||||
}, 100);
|
||||
|
||||
return () => clearTimeout(timeoutId);
|
||||
}
|
||||
}, [shouldRefit, fitView]);
|
||||
|
||||
// Process messages into nodes and edges
|
||||
const processMessages = useCallback(
|
||||
(messages: AgentMessageConfig[]) => {
|
||||
if (messages.length === 0) return { nodes: [], edges: [] };
|
||||
if (!run.task) return { nodes: [], edges: [] };
|
||||
|
||||
const nodeMap = new Map<string, Node>();
|
||||
const transitionCounts = new Map<string, MessageSequence>();
|
||||
const bidirectionalPatterns = new Map<string, BidirectionalPattern>();
|
||||
|
||||
// Process first message source
|
||||
const firstAgentConfig = teamConfig.participants.find(
|
||||
(p) => p.name === messages[0].source
|
||||
);
|
||||
nodeMap.set(
|
||||
messages[0].source,
|
||||
createNode(
|
||||
// Add first message node if it exists
|
||||
if (messages.length > 0) {
|
||||
const firstAgentConfig = teamConfig.participants.find(
|
||||
(p) => p.name === messages[0].source
|
||||
);
|
||||
nodeMap.set(
|
||||
messages[0].source,
|
||||
messages[0].source === "user" ? "user" : "agent",
|
||||
firstAgentConfig,
|
||||
false
|
||||
)
|
||||
);
|
||||
createNode(
|
||||
messages[0].source,
|
||||
messages[0].source === "user" ? "user" : "agent",
|
||||
firstAgentConfig,
|
||||
false,
|
||||
run
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Group messages by transitions
|
||||
// Process message transitions
|
||||
for (let i = 0; i < messages.length - 1; i++) {
|
||||
const currentMsg = messages[i];
|
||||
const nextMsg = messages[i + 1];
|
||||
|
@ -319,7 +320,6 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
|
|||
transition.messages.push(currentMsg);
|
||||
}
|
||||
|
||||
// Ensure all nodes are in the nodeMap
|
||||
if (!nodeMap.has(nextMsg.source)) {
|
||||
const agentConfig = teamConfig.participants.find(
|
||||
(p) => p.name === nextMsg.source
|
||||
|
@ -330,13 +330,14 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
|
|||
nextMsg.source,
|
||||
nextMsg.source === "user" ? "user" : "agent",
|
||||
agentConfig,
|
||||
false
|
||||
false,
|
||||
run
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Identify bidirectional patterns
|
||||
// Process bidirectional patterns
|
||||
transitionCounts.forEach((transition, key) => {
|
||||
const [source, target] = key.split("->");
|
||||
const reverseKey = `${target}->${source}`;
|
||||
|
@ -352,10 +353,9 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
|
|||
});
|
||||
|
||||
// Create edges with bidirectional routing
|
||||
const newEdges: Edge[] = [];
|
||||
const newEdges: CustomEdge[] = [];
|
||||
const processedKeys = new Set<string>();
|
||||
|
||||
// Helper function to create edge label based on settings
|
||||
const createEdgeLabel = (transition: MessageSequence) => {
|
||||
if (!settings.showLabels) return "";
|
||||
if (transition.totalTokens > 0) {
|
||||
|
@ -374,7 +374,6 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
|
|||
const bidirectionalPattern = bidirectionalPatterns.get(patternKey);
|
||||
|
||||
if (bidirectionalPattern) {
|
||||
// Create paired edges for bidirectional pattern
|
||||
const forwardKey = `${source}->${target}`;
|
||||
const reverseKey = `${target}->${source}`;
|
||||
|
||||
|
@ -386,7 +385,7 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
|
|||
isSecondary: boolean,
|
||||
edgeId: string,
|
||||
pairedEdgeId: string
|
||||
) => ({
|
||||
): CustomEdge => ({
|
||||
id: edgeId,
|
||||
source: transition.source,
|
||||
target: transition.target,
|
||||
|
@ -421,7 +420,6 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
|
|||
processedKeys.add(forwardKey);
|
||||
processedKeys.add(reverseKey);
|
||||
} else {
|
||||
// Handle regular edges (including self-loops)
|
||||
newEdges.push({
|
||||
id: `${transition.source}-${transition.target}-${key}`,
|
||||
source: transition.source,
|
||||
|
@ -432,7 +430,7 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
|
|||
messages: settings.showMessages ? transition.messages : [],
|
||||
},
|
||||
animated:
|
||||
threadState?.status === "streaming" &&
|
||||
run.status === "active" &&
|
||||
key === Array.from(transitionCounts.keys()).pop(),
|
||||
style: {
|
||||
stroke: "#2563eb",
|
||||
|
@ -442,25 +440,22 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
|
|||
}
|
||||
});
|
||||
|
||||
// Handle end node logic
|
||||
if (threadState && messages.length > 0) {
|
||||
// Add end node if run is complete/error/stopped
|
||||
if (run && messages.length > 0) {
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
|
||||
if (["complete", "error", "cancelled"].includes(threadState.status)) {
|
||||
nodeMap.set(
|
||||
"end",
|
||||
createNode("end", "end", undefined, false, threadState)
|
||||
);
|
||||
if (["complete", "error", "stopped"].includes(run.status)) {
|
||||
nodeMap.set("end", createNode("end", "end", undefined, false, run));
|
||||
|
||||
const edgeColor =
|
||||
{
|
||||
complete: "#2563eb",
|
||||
cancelled: "red",
|
||||
error: "red",
|
||||
streaming: "#2563eb",
|
||||
awaiting_input: "#2563eb",
|
||||
timeout: "red",
|
||||
}[threadState.status] || "#2563eb";
|
||||
const edgeColor = {
|
||||
complete: "#2563eb",
|
||||
stopped: "red",
|
||||
error: "red",
|
||||
active: "#2563eb",
|
||||
awaiting_input: "#2563eb",
|
||||
timeout: "red",
|
||||
created: "#2563eb",
|
||||
}[run.status];
|
||||
|
||||
newEdges.push({
|
||||
id: `${lastMessage.source}-end`,
|
||||
|
@ -481,12 +476,9 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
|
|||
}
|
||||
}
|
||||
|
||||
return {
|
||||
nodes: Array.from(nodeMap.values()),
|
||||
edges: newEdges,
|
||||
};
|
||||
return { nodes: Array.from(nodeMap.values()), edges: newEdges };
|
||||
},
|
||||
[teamConfig.participants, threadState, settings]
|
||||
[teamConfig.participants, run, settings]
|
||||
);
|
||||
|
||||
const handleToggleFullscreen = useCallback(() => {
|
||||
|
@ -508,7 +500,9 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
|
|||
}, [isFullscreen, handleToggleFullscreen]);
|
||||
|
||||
useEffect(() => {
|
||||
const { nodes: newNodes, edges: newEdges } = processMessages(messages);
|
||||
const { nodes: newNodes, edges: newEdges } = processMessages(
|
||||
run.messages.map((m) => m.config)
|
||||
);
|
||||
const { nodes: layoutedNodes, edges: layoutedEdges } = getLayoutedElements(
|
||||
newNodes,
|
||||
newEdges,
|
||||
|
@ -518,17 +512,22 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
|
|||
setNodes(layoutedNodes);
|
||||
setEdges(layoutedEdges);
|
||||
|
||||
if (messages.length > 0) {
|
||||
if (run.messages.length > 0) {
|
||||
setTimeout(() => {
|
||||
fitView({ padding: 0.2, duration: 200 });
|
||||
}, 50);
|
||||
}
|
||||
}, [messages, processMessages, settings.direction, threadState, fitView]);
|
||||
}, [run.messages, processMessages, settings.direction, run.status, fitView]);
|
||||
|
||||
// Define common ReactFlow props
|
||||
const reactFlowProps = {
|
||||
nodes,
|
||||
edges,
|
||||
edges: edges.map((edge) => ({
|
||||
...edge,
|
||||
data: {
|
||||
...edge.data,
|
||||
onClick: () => handleEdgeClick(edge),
|
||||
},
|
||||
})),
|
||||
nodeTypes,
|
||||
edgeTypes,
|
||||
defaultViewport: { x: 0, y: 0, zoom: 1 },
@@ -538,7 +537,6 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
proOptions: { hideAttribution: true },
|
||||
};
|
||||
|
||||
// Define common toolbar props
|
||||
const toolbarProps = useMemo(
|
||||
() => ({
|
||||
isFullscreen,
@@ -547,16 +545,16 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
}),
|
||||
[isFullscreen, handleToggleFullscreen, fitView]
|
||||
);
|
||||
|
||||
return (
|
||||
<div
|
||||
ref={flowWrapper}
|
||||
className={`transition-all duration-200 ${
|
||||
isFullscreen
|
||||
? "fixed inset-4 z-[9999] shadow" // Modal-like styling
|
||||
? "fixed inset-4 z-[50] shadow"
|
||||
: "w-full h-full min-h-[300px]"
|
||||
} bg-tertiary rounded-lg`}
|
||||
>
|
||||
{/* Backdrop when fullscreen */}
|
||||
{isFullscreen && (
|
||||
<div
|
||||
className="fixed inset-0 -z-10 bg-background/80 backdrop-blur-sm"
@@ -566,11 +564,18 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
<ReactFlow {...reactFlowProps}>
|
||||
{settings.showGrid && <Background />}
|
||||
{/* <Controls /> */}
|
||||
<div className="absolute top-0 right-0 z-50">
|
||||
<AgentFlowToolbar {...toolbarProps} />
|
||||
</div>
|
||||
</ReactFlow>
|
||||
<EdgeMessageModal
|
||||
open={modalOpen}
|
||||
onClose={() => {
|
||||
setModalOpen(false);
|
||||
setSelectedEdge(null);
|
||||
}}
|
||||
edge={selectedEdge}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
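The forward/reverse key bookkeeping near the top of this diff (`${source}->${target}` vs. `${target}->${source}`) is the usual trick for collapsing a bidirectional pair of transitions into one primary and one secondary edge. A minimal standalone sketch of that idea, with a simplified `Transition` shape assumed purely for illustration:

// Sketch only: de-duplicate A->B / B->A transition pairs using string keys.
interface Transition {
  source: string;
  target: string;
}

function pairTransitions(transitions: Transition[]) {
  const processed = new Set<string>();
  const pairs: { primary: Transition; secondary?: Transition }[] = [];

  for (const t of transitions) {
    const forwardKey = `${t.source}->${t.target}`;
    const reverseKey = `${t.target}->${t.source}`;
    if (processed.has(forwardKey) || processed.has(reverseKey)) continue;

    // If the opposite direction exists (and this is not a self-loop),
    // record it as the secondary edge of the pair.
    const reverse =
      t.source === t.target
        ? undefined
        : transitions.find((r) => r.source === t.target && r.target === t.source);
    pairs.push({ primary: t, secondary: reverse });
    processed.add(forwardKey);
    processed.add(reverseKey);
  }
  return pairs;
}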
@@ -66,7 +66,7 @@ function AgentNode({ data, isConnectable }: AgentNodeProps) {

   if (data.type === "end") {
     return {
-      wrapper: `relative min-w-[120px] shadow rounded-lg overflow-hidden ${activeStyles}`,
+      wrapper: `relative min-w-[180px] shadow rounded-lg overflow-hidden ${activeStyles}`,
       border:
         data.status === "complete"
           ? "var(--accent)"
@@ -77,7 +77,7 @@ function AgentNode({ data, isConnectable }: AgentNodeProps) {
   }

   return {
-    wrapper: `min-w-[150px] rounded-lg shadow overflow-hidden ${activeStyles}`,
+    wrapper: `min-w-[180px] rounded-lg shadow overflow-hidden ${activeStyles}`,
     border: undefined,
   };
 };
@@ -1,38 +1,23 @@
import React from "react";
|
||||
import { Tooltip } from "antd";
|
||||
import { AgentMessageConfig } from "../../../../types/datamodel";
|
||||
import {
|
||||
Edge,
|
||||
EdgeLabelRenderer,
|
||||
type EdgeProps,
|
||||
getSmoothStepPath,
|
||||
} from "@xyflow/react";
|
||||
import { RenderMessage } from "../rendermessage";
|
||||
|
||||
interface EdgeTooltipContentProps {
|
||||
messages: AgentMessageConfig[];
|
||||
}
|
||||
const EDGE_OFFSET = 140;
|
||||
|
||||
interface CustomEdgeData {
|
||||
export interface CustomEdgeData extends Record<string, unknown> {
|
||||
label?: string;
|
||||
messages: AgentMessageConfig[];
|
||||
routingType?: "primary" | "secondary";
|
||||
bidirectionalPair?: string;
|
||||
onClick?: () => void;
|
||||
}
|
||||
|
||||
const EdgeTooltipContent: React.FC<EdgeTooltipContentProps> = ({
|
||||
messages,
|
||||
}) => {
|
||||
return (
|
||||
<div className="p-2 overflow-auto max-h-[200px] scroll max-w-[350px]">
|
||||
<div className="text-xs mb-2">{messages.length} messages</div>
|
||||
<div className="edge-tooltip">
|
||||
{messages.map((message, index) => (
|
||||
<RenderMessage key={index} message={message} />
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
export type CustomEdge = Edge<CustomEdgeData>;
|
||||
|
||||
interface CustomEdgeProps extends Omit<EdgeProps, "data"> {
|
||||
data: CustomEdgeData;
@@ -63,27 +48,24 @@ export const CustomEdge: React.FC<CustomEdgeProps> = ({
let labelX = 0;
|
||||
let labelY = 0;
|
||||
|
||||
if (isSelfLoop) {
|
||||
const rightOffset = 120;
|
||||
const verticalOffset = sourceY - targetY;
|
||||
const verticalPadding = 6;
|
||||
const radius = 8;
|
||||
if (data.routingType === "secondary" || isSelfLoop) {
|
||||
// Calculate the midpoint and offset
|
||||
const midY = (sourceY + targetY) / 2;
|
||||
const offset = EDGE_OFFSET;
|
||||
|
||||
// Create path that goes out from output, right, up, left, into input
|
||||
edgePath = `
|
||||
M ${sourceX} ${targetY - verticalPadding}
|
||||
L ${sourceX + rightOffset - radius} ${targetY - verticalPadding}
|
||||
Q ${sourceX + rightOffset} ${targetY - verticalPadding} ${
|
||||
sourceX + rightOffset
|
||||
} ${targetY - verticalPadding + radius}
|
||||
L ${sourceX + rightOffset} ${sourceY + verticalPadding - radius}
|
||||
Q ${sourceX + rightOffset} ${sourceY + verticalPadding} ${
|
||||
sourceX + rightOffset - radius
|
||||
} ${sourceY + verticalPadding}
|
||||
L ${sourceX} ${sourceY + verticalPadding}
|
||||
M ${sourceX},${sourceY}
|
||||
L ${sourceX},${sourceY + 10}
|
||||
L ${sourceX + offset},${sourceY + 10}
|
||||
L ${sourceX + offset},${targetY - 10}
|
||||
L ${targetX},${targetY - 10}
|
||||
L ${targetX},${targetY}
|
||||
`;
|
||||
|
||||
labelX = sourceX + rightOffset + 10;
|
||||
labelY = targetY + verticalOffset / 2;
|
||||
// Set label position
|
||||
labelX = sourceX + offset;
|
||||
labelY = midY;
|
||||
} else {
|
||||
[edgePath, labelX, labelY] = getSmoothStepPath({
|
||||
sourceX,
@@ -98,7 +80,7 @@ export const CustomEdge: React.FC<CustomEdgeProps> = ({
if (!data.routingType || isSelfLoop) return { x, y };
|
||||
|
||||
// Make vertical separation more pronounced
|
||||
const verticalOffset = data.routingType === "secondary" ? -35 : 35; // Increased from 20 to 35
|
||||
const verticalOffset = data.routingType === "secondary" ? -35 : 35;
|
||||
const horizontalOffset = data.routingType === "secondary" ? -25 : 25;
|
||||
|
||||
// Calculate edge angle to determine if it's more horizontal or vertical
@@ -109,7 +91,7 @@ export const CustomEdge: React.FC<CustomEdgeProps> = ({
// Always apply some vertical offset
|
||||
const basePosition = {
|
||||
x: isMoreHorizontal ? x : x + horizontalOffset,
|
||||
y: y + (data.routingType === "secondary" ? -35 : 35), // Always apply vertical offset
|
||||
y: y + (data.routingType === "secondary" ? -20 : 20),
|
||||
};
|
||||
|
||||
return basePosition;
@@ -137,29 +119,23 @@ export const CustomEdge: React.FC<CustomEdgeProps> = ({
position: "absolute",
|
||||
transform: `translate(-50%, -50%) translate(${labelPosition.x}px,${labelPosition.y}px)`,
|
||||
pointerEvents: "all",
|
||||
// Add a slight transition for smooth updates
|
||||
transition: "transform 0.2s ease-in-out",
|
||||
transition: "all 0.2s ease-in-out",
|
||||
}}
|
||||
onClick={data.onClick}
|
||||
>
|
||||
<Tooltip
|
||||
title={
|
||||
data.messages && data.messages.length > 0 ? (
|
||||
<EdgeTooltipContent messages={data.messages} />
|
||||
) : (
|
||||
data?.label
|
||||
)
|
||||
}
|
||||
overlayStyle={{ maxWidth: "none" }}
|
||||
<div
|
||||
className="px-2 py-1 rounded bg-secondary hover:bg-tertiary text-primary
|
||||
cursor-pointer transform hover:scale-110 transition-all
|
||||
flex items-center gap-1"
|
||||
style={{
|
||||
whiteSpace: "nowrap",
|
||||
}}
|
||||
>
|
||||
<div
|
||||
className="px-2 py-1 rounded bg-secondary bg-opacity-50 text-primary text-sm"
|
||||
style={{
|
||||
whiteSpace: "nowrap", // Prevent label from wrapping
|
||||
}}
|
||||
>
|
||||
{data.label}
|
||||
</div>
|
||||
</Tooltip>
|
||||
{messageCount > 0 && (
|
||||
<span className="text-xs text-secondary">({messageCount})</span>
|
||||
)}
|
||||
<span className="text-sm">{data.label}</span>
|
||||
</div>
|
||||
</div>
|
||||
</EdgeLabelRenderer>
|
||||
)}
@@ -0,0 +1,103 @@
// edgemessagemodal.tsx
|
||||
import React, { useState, useMemo } from "react";
|
||||
import { Modal, Input } from "antd";
|
||||
import { RenderMessage } from "../rendermessage";
|
||||
import { CustomEdge } from "./edge";
|
||||
|
||||
const { Search } = Input;
|
||||
|
||||
interface EdgeMessageModalProps {
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
edge: CustomEdge | null;
|
||||
}
|
||||
|
||||
export const EdgeMessageModal: React.FC<EdgeMessageModalProps> = ({
|
||||
open,
|
||||
onClose,
|
||||
edge,
|
||||
}) => {
|
||||
const [searchTerm, setSearchTerm] = useState("");
|
||||
|
||||
const totalTokens = useMemo(() => {
|
||||
if (!edge?.data?.messages) return 0;
|
||||
return edge.data.messages.reduce((acc, msg) => {
|
||||
const promptTokens = msg.models_usage?.prompt_tokens || 0;
|
||||
const completionTokens = msg.models_usage?.completion_tokens || 0;
|
||||
return acc + promptTokens + completionTokens;
|
||||
}, 0);
|
||||
}, [edge?.data?.messages]);
|
||||
|
||||
const filteredMessages = useMemo(() => {
|
||||
if (!edge?.data?.messages) return [];
|
||||
if (!searchTerm) return edge.data.messages;
|
||||
|
||||
return edge.data.messages.filter(
|
||||
(msg) =>
|
||||
typeof msg.content === "string" &&
|
||||
msg.content.toLowerCase().includes(searchTerm.toLowerCase())
|
||||
);
|
||||
}, [edge?.data?.messages, searchTerm]);
|
||||
|
||||
if (!edge) return null;
|
||||
|
||||
return (
|
||||
<Modal
|
||||
title={
|
||||
<div className="space-y-2">
|
||||
<div className="font-medium">
|
||||
{edge.source} → {edge.target}
|
||||
</div>
|
||||
<div className="text-sm text-secondary flex justify-between">
|
||||
{edge.data && (
|
||||
<span>
|
||||
{edge.data.messages.length} message
|
||||
{`${edge.data.messages.length > 1 ? "s" : ""}`}
|
||||
</span>
|
||||
)}
|
||||
<span>{totalTokens.toLocaleString()} tokens</span>
|
||||
</div>
|
||||
{edge.data && edge.data.messages.length > 0 && (
|
||||
<div className="text-xs py-2 font-normal">
|
||||
{" "}
|
||||
The above represents the number of times the {`${edge.target}`}{" "}
|
||||
node sent a message{" "}
|
||||
<span className="font-semibold underline text-accent">after</span>{" "}
|
||||
the {`${edge.source}`} node.{" "}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
}
|
||||
open={open}
|
||||
onCancel={onClose}
|
||||
width={800}
|
||||
footer={null}
|
||||
>
|
||||
<div className="max-h-[70vh] overflow-y-auto space-y-4 scroll pr-2">
|
||||
<Search
|
||||
placeholder="Search message content..."
|
||||
value={searchTerm}
|
||||
onChange={(e) => setSearchTerm(e.target.value)}
|
||||
allowClear
|
||||
className="sticky top-0 z-10"
|
||||
/>
|
||||
|
||||
<div className="space-y-4 ">
|
||||
{filteredMessages.map((msg, idx) => (
|
||||
<RenderMessage
|
||||
key={idx}
|
||||
message={msg}
|
||||
isLast={idx === filteredMessages.length - 1}
|
||||
/>
|
||||
))}
|
||||
|
||||
{filteredMessages.length === 0 && (
|
||||
<div className="text-center text-secondary py-8">
|
||||
No messages found
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</Modal>
|
||||
);
|
||||
};
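Both numbers shown in the modal title above come from plain passes over `edge.data.messages`: a reduce for the token total and a filter for the search box. The same logic, written as a self-contained sketch with simplified message and usage shapes (not the real datamodel types):

// Sketch only: the reductions used by the modal, on simplified shapes.
interface UsageSketch {
  prompt_tokens?: number;
  completion_tokens?: number;
}
interface MessageSketch {
  content: unknown;
  models_usage?: UsageSketch;
}

const totalTokens = (messages: MessageSketch[]): number =>
  messages.reduce(
    (acc, m) =>
      acc +
      (m.models_usage?.prompt_tokens || 0) +
      (m.models_usage?.completion_tokens || 0),
    0
  );

const filterByTerm = (messages: MessageSketch[], term: string): MessageSketch[] =>
  term
    ? messages.filter(
        (m) =>
          typeof m.content === "string" &&
          m.content.toLowerCase().includes(term.toLowerCase())
      )
    : messages;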
@@ -4,200 +4,119 @@ import { getServerUrl } from "../../../utils";
import { SessionManager } from "../../shared/session/manager";
|
||||
import { IStatus } from "../../../types/app";
|
||||
import {
|
||||
Run,
|
||||
Message,
|
||||
ThreadStatus,
|
||||
WebSocketMessage,
|
||||
TeamConfig,
|
||||
AgentMessageConfig,
|
||||
RunStatus,
|
||||
TeamResult,
|
||||
} from "../../../types/datamodel";
|
||||
import { useConfigStore } from "../../../../hooks/store";
|
||||
import { appContext } from "../../../../hooks/provider";
|
||||
import ChatInput from "./chatinput";
|
||||
import { ModelUsage, ThreadState, TIMEOUT_CONFIG } from "./types";
|
||||
import { MessageList } from "./messagelist";
|
||||
import TeamManager from "../../shared/team/manager";
|
||||
import { teamAPI } from "../../shared/team/api";
|
||||
import { sessionAPI } from "../../shared/session/api";
|
||||
import RunView from "./runview";
|
||||
import { TIMEOUT_CONFIG } from "./types";
|
||||
|
||||
const logo = require("../../../../images/landing/welcome.svg").default;
|
||||
|
||||
export default function ChatView({
|
||||
initMessages,
|
||||
}: {
|
||||
initMessages: Message[];
|
||||
}) {
|
||||
export default function ChatView() {
|
||||
const serverUrl = getServerUrl();
|
||||
const [loading, setLoading] = React.useState(false);
|
||||
const [error, setError] = React.useState<IStatus | null>({
|
||||
status: true,
|
||||
message: "All good",
|
||||
});
|
||||
const [messages, setMessages] = React.useState<Message[]>(initMessages);
|
||||
const [threadMessages, setThreadMessages] = React.useState<
|
||||
Record<string, ThreadState>
|
||||
>({});
|
||||
const chatContainerRef = React.useRef<HTMLDivElement>(null);
|
||||
const timeoutRefs = React.useRef<Record<string, NodeJS.Timeout>>({});
|
||||
|
||||
// Core state
|
||||
const [existingRuns, setExistingRuns] = React.useState<Run[]>([]);
|
||||
const [currentRun, setCurrentRun] = React.useState<Run | null>(null);
|
||||
|
||||
const chatContainerRef = React.useRef<HTMLDivElement | null>(null);
|
||||
|
||||
// Context and config
|
||||
const { user } = React.useContext(appContext);
|
||||
const { session, sessions } = useConfigStore();
|
||||
const [activeSockets, setActiveSockets] = React.useState<
|
||||
Record<string, WebSocket>
|
||||
>({});
|
||||
const activeSocketsRef = React.useRef<Record<string, WebSocket>>({});
|
||||
const [activeSocket, setActiveSocket] = React.useState<WebSocket | null>(
|
||||
null
|
||||
);
|
||||
const [teamConfig, setTeamConfig] = React.useState<TeamConfig | null>(null);
|
||||
|
||||
const [teamConfig, setTeamConfig] = React.useState<any>(null);
|
||||
const inputTimeoutRef = React.useRef<NodeJS.Timeout | null>(null);
|
||||
const activeSocketRef = React.useRef<WebSocket | null>(null);
|
||||
|
||||
React.useEffect(() => {
|
||||
if (chatContainerRef.current) {
|
||||
chatContainerRef.current.scrollTo({
|
||||
top: chatContainerRef.current.scrollHeight,
|
||||
behavior: "smooth",
|
||||
});
|
||||
// Create a Message object from AgentMessageConfig
|
||||
const createMessage = (
|
||||
config: AgentMessageConfig,
|
||||
runId: string,
|
||||
sessionId: number
|
||||
): Message => ({
|
||||
created_at: new Date().toISOString(),
|
||||
updated_at: new Date().toISOString(),
|
||||
config,
|
||||
session_id: sessionId,
|
||||
run_id: runId,
|
||||
user_id: user?.email || undefined,
|
||||
});
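Later in this file, the `message` branch of the WebSocket handler wraps each incoming `AgentMessageConfig` with this helper and appends it to the in-memory run via a functional state update. The append step, in isolation, looks roughly like the sketch below; `RunSketch` and `MsgSketch` are simplified stand-ins for the real `Run` and `Message` types, not the actual datamodel:

// Sketch only: append a streamed message to a run without mutating state.
interface MsgSketch {
  created_at: string;
  config: { content: string; source: string };
}
interface RunSketch {
  id: string;
  messages: MsgSketch[];
}

function appendStreamedMessage(
  run: RunSketch,
  config: { content: string; source: string }
): RunSketch {
  const message: MsgSketch = {
    created_at: new Date().toISOString(),
    config,
  };
  // Return a new object so React can detect the state change.
  return { ...run, messages: [...run.messages, message] };
}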
|
||||
|
||||
// Load existing runs when session changes
|
||||
const loadSessionRuns = async () => {
|
||||
if (!session?.id || !user?.email) return;
|
||||
|
||||
try {
|
||||
const response = await sessionAPI.getSessionRuns(session.id, user.email);
|
||||
setExistingRuns(response.runs);
|
||||
} catch (error) {
|
||||
console.error("Error loading session runs:", error);
|
||||
message.error("Failed to load chat history");
|
||||
}
|
||||
}, [messages, threadMessages]);
|
||||
};
|
||||
|
||||
React.useEffect(() => {
|
||||
if (session && session.team_id && user && user.email) {
|
||||
teamAPI.getTeam(session.team_id, user?.email).then((team) => {
|
||||
if (session?.id) {
|
||||
loadSessionRuns();
|
||||
} else {
|
||||
setExistingRuns([]);
|
||||
setCurrentRun(null);
|
||||
}
|
||||
}, [session?.id]);
|
||||
|
||||
// Load team config
|
||||
React.useEffect(() => {
|
||||
if (session?.team_id && user?.email) {
|
||||
teamAPI.getTeam(session.team_id, user.email).then((team) => {
|
||||
setTeamConfig(team.config);
|
||||
// console.log("Team Config", team.config);
|
||||
});
|
||||
}
|
||||
}, [session]);
|
||||
|
||||
const updateSocket = (runId: string, socket: WebSocket | null) => {
|
||||
if (socket) {
|
||||
activeSocketsRef.current[runId] = socket;
|
||||
setActiveSockets((prev) => ({ ...prev, [runId]: socket }));
|
||||
} else {
|
||||
delete activeSocketsRef.current[runId];
|
||||
setActiveSockets((prev) => {
|
||||
const next = { ...prev };
|
||||
delete next[runId];
|
||||
return next;
|
||||
});
|
||||
}
|
||||
};
|
||||
React.useEffect(() => {
|
||||
setTimeout(() => {
|
||||
if (chatContainerRef.current && existingRuns.length > 0) {
|
||||
// Scroll to bottom to show latest run
|
||||
chatContainerRef.current.scrollTo({
|
||||
top: chatContainerRef.current.scrollHeight,
|
||||
behavior: "auto", // Use 'auto' instead of 'smooth' for initial load
|
||||
});
|
||||
}
|
||||
}, 450);
|
||||
}, [existingRuns.length, currentRun?.messages]);
|
||||
|
||||
// Cleanup socket on unmount
|
||||
React.useEffect(() => {
|
||||
return () => {
|
||||
Object.values(activeSockets).forEach((socket) => socket.close());
|
||||
if (inputTimeoutRef.current) {
|
||||
clearTimeout(inputTimeoutRef.current);
|
||||
}
|
||||
activeSocket?.close();
|
||||
};
|
||||
}, [activeSockets]);
|
||||
|
||||
const handleTimeoutForRun = (runId: string) => {
|
||||
const socket = activeSocketsRef.current[runId];
|
||||
if (socket && socket.readyState === WebSocket.OPEN) {
|
||||
// Send stop message to backend, just like when user clicks stop
|
||||
socket.send(
|
||||
JSON.stringify({
|
||||
type: "stop",
|
||||
reason: TIMEOUT_CONFIG.DEFAULT_MESSAGE,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
// Update thread state with timeout reason
|
||||
setThreadMessages((prev) => {
|
||||
const currentThread = prev[runId];
|
||||
if (!currentThread) return prev;
|
||||
|
||||
return {
|
||||
...prev,
|
||||
[runId]: {
|
||||
...currentThread,
|
||||
status: "cancelled", // Use existing cancelled status
|
||||
reason: "Input request timed out after 3 minutes",
|
||||
isExpanded: true,
|
||||
inputRequest: currentThread.inputRequest
|
||||
? {
|
||||
prompt: currentThread.inputRequest.prompt,
|
||||
isPending: true,
|
||||
}
|
||||
: undefined,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
if (timeoutRefs.current[runId]) {
|
||||
clearTimeout(timeoutRefs.current[runId]);
|
||||
delete timeoutRefs.current[runId];
|
||||
}
|
||||
};
|
||||
|
||||
const handleInputResponse = async (runId: string, response: string) => {
|
||||
// Clear timeout when response is received
|
||||
if (timeoutRefs.current[runId]) {
|
||||
clearTimeout(timeoutRefs.current[runId]);
|
||||
delete timeoutRefs.current[runId];
|
||||
}
|
||||
|
||||
if (response === "TIMEOUT") {
|
||||
handleTimeoutForRun(runId);
|
||||
return;
|
||||
}
|
||||
|
||||
const socket = activeSockets[runId];
|
||||
if (socket && socket.readyState === WebSocket.OPEN) {
|
||||
try {
|
||||
socket.send(
|
||||
JSON.stringify({
|
||||
type: "input_response",
|
||||
response: response,
|
||||
})
|
||||
);
|
||||
|
||||
setThreadMessages((prev) => ({
|
||||
...prev,
|
||||
[runId]: {
|
||||
...prev[runId],
|
||||
status: "streaming",
|
||||
inputRequest: undefined,
|
||||
},
|
||||
}));
|
||||
} catch (error) {
|
||||
console.error("Error sending input response:", error);
|
||||
message.error("Failed to send response");
|
||||
|
||||
setThreadMessages((prev) => ({
|
||||
...prev,
|
||||
[runId]: {
|
||||
...prev[runId],
|
||||
status: "error",
|
||||
reason: "Failed to send input response",
|
||||
},
|
||||
}));
|
||||
}
|
||||
} else {
|
||||
message.error("Connection lost. Please try again.");
|
||||
}
|
||||
};
|
||||
|
||||
const getBaseUrl = (url: string): string => {
|
||||
try {
|
||||
// Remove protocol (http:// or https://)
|
||||
let baseUrl = url.replace(/(^\w+:|^)\/\//, "");
|
||||
|
||||
// Handle both localhost and production cases
|
||||
if (baseUrl.startsWith("localhost")) {
|
||||
// For localhost, keep the port if it exists
|
||||
baseUrl = baseUrl.replace("/api", "");
|
||||
} else if (baseUrl === "/api") {
|
||||
// For production where url is just '/api'
|
||||
baseUrl = window.location.host;
|
||||
} else {
|
||||
// For other cases, remove '/api' and trailing slash
|
||||
baseUrl = baseUrl.replace("/api", "").replace(/\/$/, "");
|
||||
}
|
||||
|
||||
return baseUrl;
|
||||
} catch (error) {
|
||||
console.error("Error processing server URL:", error);
|
||||
throw new Error("Invalid server URL configuration");
|
||||
}
|
||||
};
|
||||
}, [activeSocket]);
|
||||
|
||||
const createRun = async (sessionId: number): Promise<string> => {
|
||||
const payload = { session_id: sessionId, user_id: user?.email || "" };
|
||||
|
||||
const response = await fetch(`${serverUrl}/runs/`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
@@ -212,334 +131,343 @@ export default function ChatView({
return data.data.run_id;
|
||||
};
|
||||
|
||||
const startRun = async (runId: string, query: string) => {
|
||||
const messagePayload = {
|
||||
user_id: user?.email,
|
||||
session_id: session?.id,
|
||||
config: {
|
||||
content: query,
|
||||
source: "user",
|
||||
},
|
||||
};
|
||||
const handleWebSocketMessage = (message: WebSocketMessage) => {
|
||||
setCurrentRun((current) => {
|
||||
if (!current || !session?.id) return null;
|
||||
|
||||
const response = await fetch(`${serverUrl}/runs/${runId}/start`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(messagePayload),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error("Failed to start run");
|
||||
}
|
||||
|
||||
return await response.json();
|
||||
};
|
||||
|
||||
const connectWebSocket = (runId: string, query: string) => {
|
||||
const baseUrl = getBaseUrl(serverUrl);
|
||||
const wsProtocol = window.location.protocol === "https:" ? "wss:" : "ws:";
|
||||
const wsUrl = `${wsProtocol}//${baseUrl}/api/ws/runs/${runId}`;
|
||||
|
||||
const socket = new WebSocket(wsUrl);
|
||||
let isClosing = false;
|
||||
|
||||
const clearTimeoutForRun = () => {
|
||||
if (timeoutRefs.current[runId]) {
|
||||
clearTimeout(timeoutRefs.current[runId]);
|
||||
delete timeoutRefs.current[runId];
|
||||
}
|
||||
};
|
||||
|
||||
const closeSocket = () => {
|
||||
if (!isClosing && socket.readyState !== WebSocket.CLOSED) {
|
||||
isClosing = true;
|
||||
socket.close();
|
||||
updateSocket(runId, null);
|
||||
}
|
||||
};
|
||||
|
||||
socket.onopen = async () => {
|
||||
try {
|
||||
updateSocket(runId, socket);
|
||||
|
||||
setThreadMessages((prev) => ({
|
||||
...prev,
|
||||
[runId]: {
|
||||
messages: [],
|
||||
status: "streaming",
|
||||
isExpanded: true,
|
||||
},
|
||||
}));
|
||||
|
||||
setMessages((prev: Message[]) =>
|
||||
prev.map((msg: Message) => {
|
||||
if (msg.run_id === runId && msg.config.source === "bot") {
|
||||
return {
|
||||
...msg,
|
||||
config: {
|
||||
...msg.config,
|
||||
content: "Starting...",
|
||||
},
|
||||
};
|
||||
}
|
||||
return msg;
|
||||
})
|
||||
);
|
||||
|
||||
await startRun(runId, query);
|
||||
} catch (error) {
|
||||
closeSocket();
|
||||
setThreadMessages((prev) => ({
|
||||
...prev,
|
||||
[runId]: {
|
||||
...prev[runId],
|
||||
status: "error",
|
||||
isExpanded: true,
|
||||
},
|
||||
}));
|
||||
}
|
||||
};
|
||||
|
||||
socket.onmessage = (event) => {
|
||||
const message: WebSocketMessage = JSON.parse(event.data);
|
||||
console.log("WebSocket message:", message);
|
||||
|
||||
switch (message.type) {
|
||||
case "input_request":
|
||||
clearTimeoutForRun();
|
||||
|
||||
timeoutRefs.current[runId] = setTimeout(() => {
|
||||
handleTimeoutForRun(runId);
|
||||
}, TIMEOUT_CONFIG.DURATION_MS);
|
||||
|
||||
setThreadMessages((prev) => ({
|
||||
...prev,
|
||||
[runId]: {
|
||||
...prev[runId],
|
||||
status: "awaiting_input",
|
||||
inputRequest: {
|
||||
prompt: message.data?.content || "",
|
||||
isPending: false,
|
||||
},
|
||||
},
|
||||
}));
|
||||
break;
|
||||
case "error":
|
||||
if (inputTimeoutRef.current) {
|
||||
clearTimeout(inputTimeoutRef.current);
|
||||
inputTimeoutRef.current = null;
|
||||
}
|
||||
if (activeSocket) {
|
||||
activeSocket.close();
|
||||
setActiveSocket(null);
|
||||
activeSocketRef.current = null;
|
||||
}
|
||||
console.log("Error: ", message.error);
|
||||
|
||||
case "message":
|
||||
clearTimeoutForRun();
|
||||
if (!message.data) return current;
|
||||
|
||||
setThreadMessages((prev) => {
|
||||
const currentThread = prev[runId] || {
|
||||
messages: [],
|
||||
status: "streaming",
|
||||
isExpanded: true,
|
||||
};
|
||||
// Create new Message object from websocket data
|
||||
const newMessage = createMessage(
|
||||
message.data as AgentMessageConfig,
|
||||
current.id,
|
||||
session.id
|
||||
);
|
||||
|
||||
const models_usage: ModelUsage | undefined = message.data
|
||||
?.models_usage
|
||||
? {
|
||||
prompt_tokens: message.data.models_usage.prompt_tokens,
|
||||
completion_tokens:
|
||||
message.data.models_usage.completion_tokens,
|
||||
}
|
||||
: undefined;
|
||||
return {
|
||||
...current,
|
||||
messages: [...current.messages, newMessage],
|
||||
};
|
||||
|
||||
return {
|
||||
...prev,
|
||||
[runId]: {
|
||||
...currentThread,
|
||||
messages: [
|
||||
...currentThread.messages,
|
||||
{
|
||||
source: message.data?.source || "",
|
||||
content: message.data?.content || "",
|
||||
models_usage,
|
||||
},
|
||||
],
|
||||
status: "streaming",
|
||||
},
|
||||
};
|
||||
});
|
||||
break;
|
||||
case "input_request":
|
||||
if (inputTimeoutRef.current) {
|
||||
clearTimeout(inputTimeoutRef.current);
|
||||
}
|
||||
|
||||
inputTimeoutRef.current = setTimeout(() => {
|
||||
const socket = activeSocketRef.current;
|
||||
console.log("Input timeout", socket);
|
||||
|
||||
if (socket?.readyState === WebSocket.OPEN) {
|
||||
socket.send(
|
||||
JSON.stringify({
|
||||
type: "stop",
|
||||
reason: TIMEOUT_CONFIG.DEFAULT_MESSAGE,
|
||||
code: TIMEOUT_CONFIG.WEBSOCKET_CODE,
|
||||
})
|
||||
);
|
||||
setCurrentRun((prev) =>
|
||||
prev
|
||||
? {
|
||||
...prev,
|
||||
status: "stopped",
|
||||
error_message: TIMEOUT_CONFIG.DEFAULT_MESSAGE,
|
||||
}
|
||||
: null
|
||||
);
|
||||
}
|
||||
}, TIMEOUT_CONFIG.DURATION_MS);
|
||||
|
||||
return {
|
||||
...current,
|
||||
status: "awaiting_input",
|
||||
};
|
||||
case "result":
|
||||
case "completion":
|
||||
clearTimeoutForRun();
|
||||
// When run completes, move it to existingRuns
|
||||
const status: RunStatus =
|
||||
message.status === "complete"
|
||||
? "complete"
|
||||
: message.status === "error"
|
||||
? "error"
|
||||
: "stopped";
|
||||
|
||||
setThreadMessages((prev) => {
|
||||
const currentThread = prev[runId];
|
||||
if (!currentThread) return prev;
|
||||
const isTeamResult = (data: any): data is TeamResult => {
|
||||
return (
|
||||
data &&
|
||||
"task_result" in data &&
|
||||
"usage" in data &&
|
||||
"duration" in data
|
||||
);
|
||||
};
|
||||
|
||||
const status: ThreadStatus = message.status || "complete";
|
||||
const reason =
|
||||
message.data?.task_result?.stop_reason ||
|
||||
(message.error ? `Error: ${message.error}` : undefined);
|
||||
const updatedRun = {
|
||||
...current,
|
||||
status,
|
||||
team_result:
|
||||
message.data && isTeamResult(message.data) ? message.data : null,
|
||||
};
|
||||
|
||||
return {
|
||||
...prev,
|
||||
[runId]: {
|
||||
...currentThread,
|
||||
status,
|
||||
reason,
|
||||
isExpanded: true,
|
||||
finalResult: message.data?.task_result?.messages
|
||||
?.filter((msg: any) => msg.content !== "TERMINATE")
|
||||
.pop(),
|
||||
},
|
||||
};
|
||||
});
|
||||
closeSocket();
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
socket.onclose = (event) => {
|
||||
clearTimeoutForRun();
|
||||
|
||||
if (!isClosing) {
|
||||
updateSocket(runId, null);
|
||||
|
||||
setThreadMessages((prev) => {
|
||||
const thread = prev[runId];
|
||||
if (thread && thread.status === "streaming") {
|
||||
return {
|
||||
...prev,
|
||||
[runId]: {
|
||||
...thread,
|
||||
status:
|
||||
event.code === TIMEOUT_CONFIG.WEBSOCKET_CODE
|
||||
? "timeout"
|
||||
: "complete",
|
||||
reason: event.reason || "Connection closed",
|
||||
},
|
||||
};
|
||||
// Add to existing runs if complete
|
||||
if (status === "complete") {
|
||||
if (inputTimeoutRef.current) {
|
||||
clearTimeout(inputTimeoutRef.current);
|
||||
inputTimeoutRef.current = null;
|
||||
}
|
||||
if (activeSocket) {
|
||||
activeSocket.close();
|
||||
setActiveSocket(null);
|
||||
activeSocketRef.current = null;
|
||||
}
|
||||
setExistingRuns((prev) => [...prev, updatedRun]);
|
||||
return null;
|
||||
}
|
||||
return prev;
|
||||
});
|
||||
|
||||
return updatedRun;
|
||||
|
||||
default:
|
||||
return current;
|
||||
}
|
||||
};
|
||||
|
||||
socket.onerror = (error) => {
|
||||
clearTimeoutForRun();
|
||||
|
||||
setThreadMessages((prev) => {
|
||||
const thread = prev[runId];
|
||||
if (!thread) return prev;
|
||||
|
||||
return {
|
||||
...prev,
|
||||
[runId]: {
|
||||
...thread,
|
||||
status: "error",
|
||||
reason: "WebSocket connection error occurred",
|
||||
isExpanded: true,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
closeSocket();
|
||||
};
|
||||
|
||||
return socket;
|
||||
});
|
||||
};
|
||||
|
||||
const cancelRun = async (runId: string) => {
|
||||
const socket = activeSockets[runId];
|
||||
if (socket && socket.readyState === WebSocket.OPEN) {
|
||||
socket.send(
|
||||
JSON.stringify({ type: "stop", reason: "Cancelled by user" })
|
||||
const handleError = (error: any) => {
|
||||
console.error("Error:", error);
|
||||
message.error("Error during request processing");
|
||||
|
||||
setCurrentRun((current) => {
|
||||
if (!current) return null;
|
||||
|
||||
const errorRun = {
|
||||
...current,
|
||||
status: "error" as const,
|
||||
error_message:
|
||||
error instanceof Error ? error.message : "Unknown error occurred",
|
||||
};
|
||||
|
||||
// Add failed run to existing runs
|
||||
setExistingRuns((prev) => [...prev, errorRun]);
|
||||
return null; // Clear current run
|
||||
});
|
||||
|
||||
setError({
|
||||
status: false,
|
||||
message:
|
||||
error instanceof Error ? error.message : "Unknown error occurred",
|
||||
});
|
||||
};
|
||||
|
||||
const handleInputResponse = async (response: string) => {
|
||||
if (!activeSocketRef.current || !currentRun) return;
|
||||
|
||||
if (activeSocketRef.current.readyState !== WebSocket.OPEN) {
|
||||
console.error(
|
||||
"Socket not in OPEN state:",
|
||||
activeSocketRef.current.readyState
|
||||
);
|
||||
handleError(new Error("WebSocket connection not available"));
|
||||
return;
|
||||
}
|
||||
|
||||
// Clear timeout when response received
|
||||
if (inputTimeoutRef.current) {
|
||||
clearTimeout(inputTimeoutRef.current);
|
||||
inputTimeoutRef.current = null;
|
||||
}
|
||||
|
||||
try {
|
||||
console.log("Sending input response:", response);
|
||||
activeSocketRef.current.send(
|
||||
JSON.stringify({
|
||||
type: "input_response",
|
||||
response: response,
|
||||
})
|
||||
);
|
||||
|
||||
setThreadMessages((prev) => ({
|
||||
...prev,
|
||||
[runId]: {
|
||||
...prev[runId],
|
||||
status: "cancelled",
|
||||
reason: "Cancelled by user",
|
||||
isExpanded: true,
|
||||
},
|
||||
}));
|
||||
setCurrentRun((current) => {
|
||||
if (!current) return null;
|
||||
return {
|
||||
...current,
|
||||
status: "active",
|
||||
};
|
||||
});
|
||||
} catch (error) {
|
||||
handleError(error);
|
||||
}
|
||||
};
|
||||
|
||||
// Clean up timeouts when component unmounts
|
||||
React.useEffect(() => {
|
||||
return () => {
|
||||
Object.entries(timeoutRefs.current).forEach(([_, timeout]) =>
|
||||
clearTimeout(timeout)
|
||||
const handleCancel = async () => {
|
||||
if (!activeSocketRef.current || !currentRun) return;
|
||||
|
||||
// Clear timeout when manually cancelled
|
||||
if (inputTimeoutRef.current) {
|
||||
clearTimeout(inputTimeoutRef.current);
|
||||
inputTimeoutRef.current = null;
|
||||
}
|
||||
try {
|
||||
activeSocketRef.current.send(
|
||||
JSON.stringify({
|
||||
type: "stop",
|
||||
reason: "Cancelled by user",
|
||||
})
|
||||
);
|
||||
timeoutRefs.current = {};
|
||||
};
|
||||
}, []);
|
||||
|
||||
setCurrentRun((current) => {
|
||||
if (!current) return null;
|
||||
return {
|
||||
...current,
|
||||
status: "stopped",
|
||||
};
|
||||
});
|
||||
} catch (error) {
|
||||
handleError(error);
|
||||
}
|
||||
};
|
||||
|
||||
const runTask = async (query: string) => {
|
||||
setError(null);
|
||||
setLoading(true);
|
||||
|
||||
if (!session?.id) {
|
||||
// Add explicit cleanup
|
||||
if (activeSocket) {
|
||||
activeSocket.close();
|
||||
setActiveSocket(null);
|
||||
activeSocketRef.current = null;
|
||||
}
|
||||
if (inputTimeoutRef.current) {
|
||||
clearTimeout(inputTimeoutRef.current);
|
||||
inputTimeoutRef.current = null;
|
||||
}
|
||||
|
||||
if (!session?.id || !teamConfig) {
|
||||
// Add teamConfig check
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
let runId: string | null = null;
|
||||
|
||||
try {
|
||||
runId = (await createRun(session.id)) + "";
|
||||
const runId = await createRun(session.id);
|
||||
|
||||
const userMessage: Message = {
|
||||
config: {
|
||||
// Initialize run state BEFORE websocket connection
|
||||
setCurrentRun({
|
||||
id: runId,
|
||||
created_at: new Date().toISOString(),
|
||||
status: "created", // Start with created status
|
||||
messages: [],
|
||||
task: {
|
||||
content: query,
|
||||
source: "user",
|
||||
},
|
||||
session_id: session.id,
|
||||
run_id: runId,
|
||||
};
|
||||
|
||||
const botMessage: Message = {
|
||||
config: {
|
||||
content: "Thinking...",
|
||||
source: "bot",
|
||||
},
|
||||
session_id: session.id,
|
||||
run_id: runId,
|
||||
};
|
||||
|
||||
setMessages((prev) => [...prev, userMessage, botMessage]);
|
||||
connectWebSocket(runId, query); // Now passing query to connectWebSocket
|
||||
} catch (err) {
|
||||
console.error("Error:", err);
|
||||
message.error("Error during request processing");
|
||||
|
||||
if (runId) {
|
||||
if (activeSockets[runId]) {
|
||||
activeSockets[runId].close();
|
||||
}
|
||||
|
||||
setThreadMessages((prev) => ({
|
||||
...prev,
|
||||
[runId!]: {
|
||||
...prev[runId!],
|
||||
status: "error",
|
||||
isExpanded: true,
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
setError({
|
||||
status: false,
|
||||
message: err instanceof Error ? err.message : "Unknown error occurred",
|
||||
team_result: null,
|
||||
error_message: undefined,
|
||||
});
|
||||
|
||||
// Setup WebSocket
|
||||
const socket = setupWebSocket(runId, query);
|
||||
setActiveSocket(socket);
|
||||
activeSocketRef.current = socket;
|
||||
} catch (error) {
|
||||
handleError(error);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
React.useEffect(() => {
|
||||
// session changed
|
||||
if (session) {
|
||||
setMessages([]);
|
||||
setThreadMessages({});
|
||||
const setupWebSocket = (runId: string, query: string): WebSocket => {
|
||||
if (!session || !session.id) {
|
||||
throw new Error("Invalid session configuration");
|
||||
}
|
||||
}, [session]);
|
||||
// Close existing socket if any
|
||||
if (activeSocket?.readyState === WebSocket.OPEN) {
|
||||
activeSocket.close();
|
||||
}
|
||||
|
||||
const baseUrl = getBaseUrl(serverUrl);
|
||||
const wsProtocol = window.location.protocol === "https:" ? "wss:" : "ws:";
|
||||
const wsUrl = `${wsProtocol}//${baseUrl}/api/ws/runs/${runId}`;
|
||||
|
||||
const socket = new WebSocket(wsUrl);
|
||||
|
||||
// Initialize current run
|
||||
setCurrentRun({
|
||||
id: runId,
|
||||
created_at: new Date().toISOString(),
|
||||
status: "active",
|
||||
task: createMessage(
|
||||
{ content: query, source: "user" },
|
||||
runId,
|
||||
session.id || 0
|
||||
).config,
|
||||
team_result: null,
|
||||
messages: [],
|
||||
error_message: undefined,
|
||||
});
|
||||
|
||||
socket.onopen = () => {
|
||||
// Send start message with teamConfig
|
||||
socket.send(
|
||||
JSON.stringify({
|
||||
type: "start",
|
||||
task: query,
|
||||
team_config: teamConfig,
|
||||
})
|
||||
);
|
||||
};
|
||||
|
||||
socket.onmessage = (event) => {
|
||||
try {
|
||||
const message = JSON.parse(event.data);
|
||||
handleWebSocketMessage(message);
|
||||
} catch (error) {
|
||||
console.error("WebSocket message parsing error:", error);
|
||||
}
|
||||
};
|
||||
|
||||
socket.onclose = () => {
|
||||
activeSocketRef.current = null;
|
||||
setActiveSocket(null);
|
||||
};
|
||||
|
||||
socket.onerror = (error) => {
|
||||
handleError(error);
|
||||
};
|
||||
|
||||
return socket;
|
||||
};
|
||||
|
||||
// Helper for WebSocket URL
|
||||
const getBaseUrl = (url: string): string => {
|
||||
try {
|
||||
let baseUrl = url.replace(/(^\w+:|^)\/\//, "");
|
||||
if (baseUrl.startsWith("localhost")) {
|
||||
baseUrl = baseUrl.replace("/api", "");
|
||||
} else if (baseUrl === "/api") {
|
||||
baseUrl = window.location.host;
|
||||
} else {
|
||||
baseUrl = baseUrl.replace("/api", "").replace(/\/$/, "");
|
||||
}
|
||||
return baseUrl;
|
||||
} catch (error) {
|
||||
console.error("Error processing server URL:", error);
|
||||
throw new Error("Invalid server URL configuration");
|
||||
}
|
||||
};
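As a usage note, `setupWebSocket` above combines this helper with the page protocol to build the per-run socket endpoint. Pulled out on its own, the derivation is roughly the following sketch; the `/api/ws/runs/{run_id}` path is taken from the code in this file, while the simplified host handling is an assumption for illustration:

// Sketch only: derive the WebSocket endpoint for a run from the HTTP server URL.
function runSocketUrl(serverUrl: string, runId: string): string {
  // Strip the protocol, drop "/api", and fall back to the current host.
  let base = serverUrl.replace(/(^\w+:|^)\/\//, "").replace("/api", "");
  base = base.replace(/\/$/, "") || window.location.host;
  const wsProtocol = window.location.protocol === "https:" ? "wss:" : "ws:";
  return `${wsProtocol}//${base}/api/ws/runs/${runId}`;
}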
|
||||
|
||||
return (
|
||||
<div className="text-primary h-[calc(100vh-195px)] bg-primary relative rounded flex-1 scroll">
@@ -549,45 +477,62 @@ export default function ChatView({
</div>
|
||||
<TeamManager />
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col h-full">
|
||||
<div
|
||||
className="flex-1 overflow-y-auto scroll mt-2 relative min-h-0"
|
||||
ref={chatContainerRef}
|
||||
className="flex-1 overflow-y-auto scroll mt-2 min-h-0 relative"
|
||||
>
|
||||
<MessageList
|
||||
messages={messages}
|
||||
threadMessages={threadMessages}
|
||||
setThreadMessages={setThreadMessages}
|
||||
onRetry={runTask}
|
||||
onCancel={cancelRun}
|
||||
onInputResponse={handleInputResponse} // Add the new prop
|
||||
loading={loading}
|
||||
teamConfig={teamConfig}
|
||||
/>
|
||||
<div id="scroll-gradient" className="scroll-gradient h-8 top-0">
|
||||
{" "}
|
||||
<span className=" inline-block h-6"></span>{" "}
|
||||
</div>
|
||||
{sessions !== null && sessions?.length === 0 ? (
|
||||
<div className="flex h-[calc(100%-100px)] flex-col items-center justify-center w-full">
|
||||
<div className="mt-4 text-sm text-secondary text-center">
|
||||
<img src={logo} alt="Welcome" className="w-72 h-72 mb-4" />
|
||||
Welcome! Create a session to get started!
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
{teamConfig && (
|
||||
<>
|
||||
{/* Existing Runs */}
|
||||
{existingRuns.map((run, index) => (
|
||||
<RunView
|
||||
teamConfig={teamConfig}
|
||||
key={run.id + "-review-" + index}
|
||||
run={run}
|
||||
isFirstRun={index === 0}
|
||||
/>
|
||||
))}
|
||||
|
||||
{/* Current Run */}
|
||||
{currentRun && (
|
||||
<RunView
|
||||
run={currentRun}
|
||||
teamConfig={teamConfig}
|
||||
onInputResponse={handleInputResponse}
|
||||
onCancel={handleCancel}
|
||||
isFirstRun={existingRuns.length === 0}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{sessions !== null && sessions?.length === 0 ? (
|
||||
<div className="flex h-[calc(100%-100px)] flex-col items-center justify-center w-full">
|
||||
<div className="mt-4 text-sm text-secondary text-center">
|
||||
<img src={logo} alt="Welcome" className="w-72 h-72 mb-4" />
|
||||
Welcome! Create a session to get started!
|
||||
</div>
|
||||
{session && (
|
||||
<div className="flex-shrink-0">
|
||||
<ChatInput
|
||||
onSubmit={runTask}
|
||||
loading={loading}
|
||||
error={error}
|
||||
disabled={currentRun?.status === "awaiting_input"}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
{session && (
|
||||
<div className="flex-shrink-0">
|
||||
<ChatInput
|
||||
onSubmit={runTask}
|
||||
loading={loading}
|
||||
error={error}
|
||||
disabled={Object.values(threadMessages).some(
|
||||
(thread) => thread.status === "awaiting_input"
|
||||
)} // Disable input while waiting for user input
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
@@ -12,15 +12,6 @@ import { ThreadState, TIMEOUT_CONFIG } from "./types";
import { RenderMessage } from "./rendermessage";
|
||||
import LoadingDots from "../../shared/atoms";
|
||||
|
||||
interface ThreadViewProps {
|
||||
thread: ThreadState;
|
||||
isStreaming: boolean;
|
||||
runId: string;
|
||||
onCancel: (runId: string) => void;
|
||||
onInputResponse: (runId: string, response: string) => void;
|
||||
threadContainerRef: (el: HTMLDivElement | null) => void;
|
||||
}
|
||||
|
||||
interface InputRequestProps {
|
||||
prompt: string;
|
||||
onSubmit: (response: string) => void;
|
||||
|
@@ -170,105 +161,4 @@ const InputRequestView: React.FC<InputRequestProps> = ({
);
|
||||
};
|
||||
|
||||
const ThreadView: React.FC<ThreadViewProps> = ({
|
||||
thread,
|
||||
isStreaming,
|
||||
runId,
|
||||
onCancel,
|
||||
onInputResponse,
|
||||
threadContainerRef,
|
||||
}) => {
|
||||
const isAwaitingInput = thread.status === "awaiting_input";
|
||||
const isTimedOut = thread.status === "timeout";
|
||||
|
||||
const getStatusIcon = () => {
|
||||
switch (thread.status) {
|
||||
case "streaming":
|
||||
return <Loader2 size={16} className="animate-spin text-accent" />;
|
||||
case "awaiting_input":
|
||||
return <MessageSquare size={16} className="text-accent" />;
|
||||
case "complete":
|
||||
return <CheckCircle size={16} className="text-accent" />;
|
||||
case "error":
|
||||
return <AlertTriangle size={16} className="text-red-500" />;
|
||||
case "timeout":
|
||||
return <Clock size={16} className="text-red-500" />;
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
const getStatusText = () => {
|
||||
if (isStreaming) {
|
||||
return (
|
||||
<>
|
||||
<span className="inline-block mr-2">Agents working</span>
|
||||
<LoadingDots size={8} />
|
||||
</>
|
||||
);
|
||||
}
|
||||
if (isAwaitingInput) return "Waiting for your input";
|
||||
if (isTimedOut) return TIMEOUT_CONFIG.DEFAULT_MESSAGE;
|
||||
if (thread.reason)
|
||||
return (
|
||||
<>
|
||||
<span className="font-semibold mr-2">Stop Reason:</span>
|
||||
{thread.reason}
|
||||
</>
|
||||
);
|
||||
return null;
|
||||
};
|
||||
|
||||
const handleTimeout = () => {
|
||||
if (thread.inputRequest) {
|
||||
onInputResponse(runId, "TIMEOUT");
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="mt-2 border border-secondary rounded bg-primary">
|
||||
<div className="sticky top-0 z-10 flex bg-primary rounded-t items-center justify-between p-3 border-b border-secondary bg-secondary/10">
|
||||
<div className="text-sm text-primary flex items-center gap-2">
|
||||
{getStatusIcon()}
|
||||
{getStatusText()}
|
||||
</div>
|
||||
{(isStreaming || isAwaitingInput) && (
|
||||
<button
|
||||
onClick={() => onCancel(runId)}
|
||||
className="flex items-center gap-1 px-3 py-1 rounded bg-red-500 hover:bg-red-600 text-white text-xs font-medium transition-colors"
|
||||
>
|
||||
<StopCircle size={12} />
|
||||
<span>Stop</span>
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div
|
||||
ref={threadContainerRef}
|
||||
className="max-h-[400px] overflow-y-auto scroll"
|
||||
>
|
||||
<div className="p-3 space-y-3">
|
||||
{thread.messages.map((threadMsg, threadIndex) => (
|
||||
<div key={`thread-${threadIndex}`}>
|
||||
<RenderMessage
|
||||
message={threadMsg}
|
||||
isLast={threadIndex === thread.messages.length - 1}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
|
||||
{thread.inputRequest && (
|
||||
<InputRequestView
|
||||
prompt={thread.inputRequest.prompt}
|
||||
onSubmit={(response) => onInputResponse(runId, response)}
|
||||
disabled={!isAwaitingInput || isTimedOut}
|
||||
onTimeout={handleTimeout}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default ThreadView;
|
||||
export default InputRequestView;
|
|
@@ -1,289 +0,0 @@
import React from "react";
|
||||
import { ThreadState } from "./types";
|
||||
import {
|
||||
AgentMessageConfig,
|
||||
Message,
|
||||
TeamConfig,
|
||||
} from "../../../types/datamodel";
|
||||
import { RenderMessage } from "./rendermessage";
|
||||
import {
|
||||
StopCircle,
|
||||
User,
|
||||
Network,
|
||||
MessageSquare,
|
||||
Loader2,
|
||||
CheckCircle,
|
||||
AlertTriangle,
|
||||
TriangleAlertIcon,
|
||||
GroupIcon,
|
||||
} from "lucide-react";
|
||||
import AgentFlow from "./agentflow/agentflow";
|
||||
import ThreadView from "./threadview";
|
||||
import LoadingDots from "../../shared/atoms";
|
||||
|
||||
interface MessageListProps {
|
||||
messages: Message[];
|
||||
threadMessages: Record<string, ThreadState>;
|
||||
setThreadMessages: React.Dispatch<
|
||||
React.SetStateAction<Record<string, ThreadState>>
|
||||
>;
|
||||
onRetry: (content: string) => void;
|
||||
onCancel: (runId: string) => void;
|
||||
onInputResponse: (runId: string, response: string) => void;
|
||||
loading?: boolean;
|
||||
teamConfig?: TeamConfig;
|
||||
}
|
||||
|
||||
interface MessagePair {
|
||||
userMessage: Message;
|
||||
botMessage: Message;
|
||||
}
|
||||
|
||||
export const MessageList: React.FC<MessageListProps> = ({
|
||||
messages,
|
||||
threadMessages,
|
||||
setThreadMessages,
|
||||
onRetry,
|
||||
onCancel,
|
||||
onInputResponse, // New prop
|
||||
loading = false,
|
||||
teamConfig,
|
||||
}) => {
|
||||
const messagePairs = React.useMemo(() => {
|
||||
const pairs: MessagePair[] = [];
|
||||
for (let i = 0; i < messages.length; i += 2) {
|
||||
if (messages[i] && messages[i + 1]) {
|
||||
pairs.push({
|
||||
userMessage: messages[i],
|
||||
botMessage: messages[i + 1],
|
||||
});
|
||||
}
|
||||
}
|
||||
return pairs;
|
||||
}, [messages]);
|
||||
|
||||
// Create a ref map to store refs for each thread container
|
||||
const threadContainerRefs = React.useRef<
|
||||
Record<string, HTMLDivElement | null>
|
||||
>({});
|
||||
|
||||
// Effect to handle scrolling when thread messages update
|
||||
React.useEffect(() => {
|
||||
Object.entries(threadMessages).forEach(([runId, thread]) => {
|
||||
if (thread.isExpanded && threadContainerRefs.current[runId]) {
|
||||
const container = threadContainerRefs.current[runId];
|
||||
if (container) {
|
||||
container.scrollTo({
|
||||
top: container.scrollHeight,
|
||||
behavior: "smooth",
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
}, [threadMessages]);
|
||||
|
||||
const toggleThread = (runId: string) => {
|
||||
setThreadMessages((prev) => ({
|
||||
...prev,
|
||||
[runId]: {
|
||||
...prev[runId],
|
||||
isExpanded: !prev[runId]?.isExpanded,
|
||||
},
|
||||
}));
|
||||
};
|
||||
|
||||
const calculateThreadTokens = (messages: AgentMessageConfig[]) => {
|
||||
return messages.reduce((total, msg) => {
|
||||
if (!msg.models_usage) return total;
|
||||
return (
|
||||
total +
|
||||
(msg.models_usage.prompt_tokens || 0) +
|
||||
(msg.models_usage.completion_tokens || 0)
|
||||
);
|
||||
}, 0);
|
||||
};
|
||||
|
||||
const getStatusIcon = (status: ThreadState["status"]) => {
|
||||
switch (status) {
|
||||
case "streaming":
|
||||
return (
|
||||
<div className="inline-block mr-1">
|
||||
<Loader2
|
||||
size={20}
|
||||
className="inline-block mr-1 text-accent animate-spin"
|
||||
/>{" "}
|
||||
<span className="inline-block mr-2">Processing</span>{" "}
|
||||
<LoadingDots size={8} />
|
||||
</div>
|
||||
);
|
||||
case "awaiting_input": // New status
|
||||
return (
|
||||
<div className="text-sm mb-2">
|
||||
<MessageSquare
|
||||
size={20}
|
||||
className="inline-block mr-1 text-accent"
|
||||
/>{" "}
|
||||
Waiting for your input
|
||||
</div>
|
||||
);
|
||||
case "complete":
|
||||
return (
|
||||
<div className="text-sm mb-2">
|
||||
<CheckCircle size={20} className="inline-block mr-1 text-accent" />{" "}
|
||||
Task completed
|
||||
</div>
|
||||
);
|
||||
case "error":
|
||||
return (
|
||||
<div className="text-sm mb-2">
|
||||
<AlertTriangle
|
||||
size={20}
|
||||
className="inline-block mr-1 text-red-500"
|
||||
/>{" "}
|
||||
An error occurred.
|
||||
</div>
|
||||
);
|
||||
case "cancelled":
|
||||
return (
|
||||
<div className="text-sm mb-2">
|
||||
<StopCircle size={20} className="inline-block mr-1 text-red-500" />{" "}
|
||||
Task was cancelled.
|
||||
</div>
|
||||
);
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="space-y-6 p-4 h-full">
|
||||
{messagePairs.map(({ userMessage, botMessage }, pairIndex) => {
|
||||
const isLast = pairIndex === messagePairs.length - 1;
|
||||
const thread = threadMessages[botMessage.run_id];
|
||||
const hasThread = thread && thread.messages.length > 0;
|
||||
const isStreaming = thread?.status === "streaming";
|
||||
const isAwaitingInput = thread?.status === "awaiting_input"; // New check
|
||||
|
||||
const isFirstMessage = pairIndex === 0;
|
||||
|
||||
return (
|
||||
<div key={`pair-${botMessage.run_id}`} className="space-y-6">
|
||||
{/* User message */}
|
||||
{
|
||||
<div
|
||||
className={`${
|
||||
isFirstMessage ? "mb-2" : "mt-8"
|
||||
} mb-4 pt-2 border-t border-dashed border-secondary`}
|
||||
>
|
||||
{/* <div>Task Run 1. </div> */}
|
||||
<div className="text-xs text-secondary">
|
||||
Run {pairIndex + 1}
|
||||
{!isFirstMessage && (
|
||||
<>
|
||||
{" "}
|
||||
|{" "}
|
||||
<TriangleAlertIcon className="w-4 h-4 -mt-1 inline-block mr-1 ml-1" />{" "}
|
||||
Note: Each run does not share data with previous runs in
|
||||
the same session yet.{" "}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
<div className="flex flex-col items-end">
|
||||
<div className="flex items-center gap-2 mb-1">
|
||||
<span className="text-sm font-medium text-primary">You</span>
|
||||
<div className="p-1.5 rounded bg-secondary text-accent">
|
||||
<User size={20} />
|
||||
</div>
|
||||
</div>
|
||||
<div className="w-full">
|
||||
<RenderMessage message={userMessage.config} isLast={false} />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Team response */}
|
||||
<div className="flex flex-col items-start">
|
||||
<div className="flex items-center gap-2 mb-1">
|
||||
<div className="p-1.5 rounded bg-secondary text-primary">
|
||||
<GroupIcon size={20} />
|
||||
</div>
|
||||
<span className="text-sm font-medium text-primary">
|
||||
Agent Team
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{/* Main response container */}
|
||||
<div className="w-full">
|
||||
<div className="p-4 bg-tertiary bordder border-secondary rounded">
|
||||
<div className="text-primary">
|
||||
{getStatusIcon(thread?.status)}{" "}
|
||||
{!isAwaitingInput && thread?.finalResult?.content}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Thread section */}
|
||||
{hasThread && (
|
||||
<div className="mt-2 pl-4 border-l-2 border-secondary/30">
|
||||
<div className="flex pt-2">
|
||||
<div className="flex-1">
|
||||
<button
|
||||
onClick={() => toggleThread(botMessage.run_id)}
|
||||
className="flex items-center gap-1 text-sm text-secondary hover:text-primary transition-colors"
|
||||
>
|
||||
<MessageSquare size={16} />
|
||||
<span className="text-accent">
|
||||
{thread.isExpanded ? "Hide" : "Show"}
|
||||
</span>{" "}
|
||||
agent discussion
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div className="text-sm text-secondary">
|
||||
{calculateThreadTokens(thread.messages)} tokens |{" "}
|
||||
{thread.messages.length} messages
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-row gap-4">
|
||||
<div className="flex-1">
|
||||
{thread.isExpanded && (
|
||||
<ThreadView
|
||||
thread={thread}
|
||||
isStreaming={isStreaming}
|
||||
runId={botMessage.run_id}
|
||||
onCancel={onCancel}
|
||||
onInputResponse={onInputResponse} // Pass through the new prop
|
||||
threadContainerRef={(el) =>
|
||||
(threadContainerRefs.current[botMessage.run_id] =
|
||||
el)
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
<div className="bg-tertiary flex-1 rounded mt-2">
|
||||
{teamConfig && thread.isExpanded && (
|
||||
<AgentFlow
|
||||
teamConfig={teamConfig}
|
||||
messages={thread.messages}
|
||||
threadState={thread}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
|
||||
{messages.length === 0 && !loading && (
|
||||
<div className="text-center text-secondary h-full">
|
||||
<div className="text-sm mt-4">Send a message to begin!</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
@@ -58,9 +58,9 @@ const RenderToolCall: React.FC<{ content: FunctionCall[] }> = ({ content }) => (
   {content.map((call) => (
     <div key={call.id} className="border rounded p-2">
       <div className="font-medium">Function: {call.name}</div>
-      <pre className="text-sm mt-1 bg-secondary p-2 rounded">
+      <div className="text-sm mt-1 bg-secondary p-2 rounded">
         {JSON.stringify(JSON.parse(call.arguments), null, 2)}
-      </pre>
+      </div>
     </div>
   ))}
 </div>
@@ -90,9 +90,9 @@ const RenderToolResult: React.FC<{ content: FunctionExecutionResult[] }> = ({
 }) => (
   <div className="space-y-2">
     {content.map((result) => (
-      <div key={result.call_id} className=" rounded p-2">
+      <div key={result.call_id} className=" rounded p-2 ">
         <div className="font-medium">Result ID: {result.call_id}</div>
-        <pre className="text-sm mt-1 bg-secondary p-2 border rounded">
+        <pre className="text-sm mt-1 bg-secondary p-2 border rounded scroll overflow-x-scroll">
           {result.content}
         </pre>
       </div>
@@ -105,6 +105,7 @@ export const RenderMessage: React.FC<MessageProps> = ({
   isLast = false,
   className = "",
 }) => {
+  if (!message) return null;
   const isUser = messageUtils.isUser(message.source);
   const content = message.content;

@@ -0,0 +1,271 @@
import React, { useState, useRef, useEffect } from "react";
|
||||
import {
|
||||
StopCircle,
|
||||
MessageSquare,
|
||||
Loader2,
|
||||
CheckCircle,
|
||||
AlertTriangle,
|
||||
TriangleAlertIcon,
|
||||
GroupIcon,
|
||||
} from "lucide-react";
|
||||
import { Run, Message, TeamConfig } from "../../../types/datamodel";
|
||||
import AgentFlow from "./agentflow/agentflow";
|
||||
import { RenderMessage } from "./rendermessage";
|
||||
import LoadingDots from "../../shared/atoms";
|
||||
import InputRequestView from "./inputrequest";
|
||||
import { Tooltip } from "antd";
|
||||
|
||||
interface RunViewProps {
|
||||
run: Run;
|
||||
teamConfig?: TeamConfig;
|
||||
onInputResponse?: (response: string) => void;
|
||||
onCancel?: () => void;
|
||||
isFirstRun?: boolean;
|
||||
}
|
||||
|
||||
const RunView: React.FC<RunViewProps> = ({
|
||||
run,
|
||||
onInputResponse,
|
||||
onCancel,
|
||||
teamConfig,
|
||||
isFirstRun = false,
|
||||
}) => {
|
||||
const [isExpanded, setIsExpanded] = useState(true);
|
||||
const threadContainerRef = useRef<HTMLDivElement | null>(null);
|
||||
const isActive = run.status === "active" || run.status === "awaiting_input";
|
||||
|
||||
// Replace existing scroll effect with this simpler one
|
||||
useEffect(() => {
|
||||
setTimeout(() => {
|
||||
if (threadContainerRef.current) {
|
||||
threadContainerRef.current.scrollTo({
|
||||
top: threadContainerRef.current.scrollHeight,
|
||||
behavior: "smooth",
|
||||
});
|
||||
}
|
||||
}, 450);
|
||||
}, [run.messages]); // Only depend on messages changing
|
||||
|
||||
const calculateThreadTokens = (messages: Message[]) => {
|
||||
return messages.reduce((total, msg) => {
|
||||
if (!msg.config.models_usage) return total;
|
||||
return (
|
||||
total +
|
||||
(msg.config.models_usage.prompt_tokens || 0) +
|
||||
(msg.config.models_usage.completion_tokens || 0)
|
||||
);
|
||||
}, 0);
|
||||
};
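The token figure shown next to the thread toggle further down comes straight from this reduce; for instance, with a hypothetical pair of messages carrying usage data:

// Sketch only: what calculateThreadTokens returns for a small example list.
const exampleMessages = [
  { config: { models_usage: { prompt_tokens: 120, completion_tokens: 30 } } },
  { config: { models_usage: { prompt_tokens: 80, completion_tokens: 25 } } },
  { config: {} }, // entries without usage are skipped
];
// calculateThreadTokens(exampleMessages as any) === 255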
|
||||
|
||||
const getStatusIcon = (status: Run["status"]) => {
|
||||
switch (status) {
|
||||
case "active":
|
||||
return (
|
||||
<div className="inline-block mr-1">
|
||||
<Loader2
|
||||
size={20}
|
||||
className="inline-block mr-1 text-accent animate-spin"
|
||||
/>
|
||||
<span className="inline-block mr-2 ml-1 ">Processing</span>
|
||||
<LoadingDots size={8} />
|
||||
</div>
|
||||
);
|
||||
case "awaiting_input":
|
||||
return (
|
||||
<div className="text-sm mb-2">
|
||||
<MessageSquare
|
||||
size={20}
|
||||
className="inline-block mr-2 text-accent"
|
||||
/>
|
||||
<span className="inline-block mr-2">Waiting for your input </span>
|
||||
<LoadingDots size={8} />
|
||||
</div>
|
||||
);
|
||||
case "complete":
|
||||
return (
|
||||
<div className="text-sm mb-2">
|
||||
<CheckCircle size={20} className="inline-block mr-2 text-accent" />
|
||||
Task completed
|
||||
</div>
|
||||
);
|
||||
case "error":
|
||||
return (
|
||||
<div className="text-sm mb-2">
|
||||
<AlertTriangle
|
||||
size={20}
|
||||
className="inline-block mr-2 text-red-500"
|
||||
/>
|
||||
{run.error_message || "An error occurred"}
|
||||
</div>
|
||||
);
|
||||
case "stopped":
|
||||
return (
|
||||
<div className="text-sm mb-2">
|
||||
<StopCircle size={20} className="inline-block mr-2 text-red-500" />
|
||||
Task was stopped
|
||||
</div>
|
||||
);
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="space-y-6 mt-6 mr-2 ">
|
||||
{/* Run Header */}
|
||||
<div
|
||||
className={`${
|
||||
isFirstRun ? "mb-2" : "mt-8"
|
||||
} mb-4 pt-2 border-t border-dashed border-secondary`}
|
||||
>
|
||||
<div className="text-xs text-secondary">
|
||||
<Tooltip
|
||||
title={
|
||||
<div className="text-xs">
|
||||
<div>ID: {run.id}</div>
|
||||
<div>Created: {new Date(run.created_at).toLocaleString()}</div>
|
||||
<div>Status: {run.status}</div>
|
||||
</div>
|
||||
}
|
||||
>
|
||||
<span className="cursor-help">Run ...{run.id.slice(-6)}</span>
|
||||
</Tooltip>
|
||||
{!isFirstRun && (
|
||||
<>
|
||||
{" "}
|
||||
|{" "}
|
||||
<TriangleAlertIcon className="w-4 h-4 -mt-1 inline-block mr-1 ml-1" />
|
||||
Note: Each run does not share data with previous runs in the same
|
||||
session yet.
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* User Message */}
|
||||
<div className="flex flex-col items-end w-full">
|
||||
<div className="w-full">
|
||||
<RenderMessage message={run.task} isLast={false} />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Team Response */}
|
||||
<div className="flex flex-col items-start">
|
||||
<div className="flex items-center gap-2 mb-1">
|
||||
<div className="p-1.5 rounded bg-secondary text-primary">
|
||||
<GroupIcon size={20} />
|
||||
</div>
|
||||
<span className="text-sm font-medium text-primary">Agent Team</span>
|
||||
</div>
|
||||
|
||||
<div className=" w-full">
|
||||
{/* Main Response Container */}
|
||||
<div className="p-4 bg-secondary border border-secondary rounded">
|
||||
<div className="flex justify-between items-start mb-2">
|
||||
<div className="text-primary">{getStatusIcon(run.status)}</div>
|
||||
|
||||
{/* Cancel Button - More prominent placement */}
|
||||
{isActive && onCancel && (
|
||||
<button
|
||||
onClick={onCancel}
|
||||
className="px-4 text-sm py-2 bg-red-500 hover:bg-red-600 text-white rounded-md transition-colors flex items-center gap-2"
|
||||
>
|
||||
<StopCircle size={16} />
|
||||
Cancel Run
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Final Response */}
|
||||
{run.status !== "awaiting_input" && run.status !== "active" && (
|
||||
<div className="text-sm">
|
||||
<div className="text-xs mb-1 text-secondary -mt-2 border rounded p-2">
|
||||
Stop reason: {run.team_result?.task_result?.stop_reason}
|
||||
</div>
|
||||
{run.messages[run.messages.length - 1]?.config?.content + ""}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Thread Section */}
|
||||
<div className="">
|
||||
{run.messages.length > 0 && (
|
||||
<div className="mt-2 pl-4 border-l-2 border-secondary/30">
|
||||
<div className="flex pt-2">
|
||||
<div className="flex-1">
|
||||
<button
|
||||
onClick={() => setIsExpanded(!isExpanded)}
|
||||
className="flex items-center gap-1 text-sm text-secondary hover:text-primary transition-colors"
|
||||
>
|
||||
<MessageSquare size={16} />
|
||||
<span className="text-accent">
|
||||
{isExpanded ? "Hide" : "Show"}
|
||||
</span>{" "}
|
||||
agent discussion
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div className="text-sm text-secondary">
|
||||
{calculateThreadTokens(run.messages)} tokens |{" "}
|
||||
{run.messages.length} messages
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{isExpanded && (
|
||||
<div className="flex flex-row gap-4">
|
||||
{/* Messages Thread */}
|
||||
<div
|
||||
ref={threadContainerRef}
|
||||
className="flex-1 mt-2 overflow-y-auto max-h-[400px] scroll-smooth scroll pb-2 relative"
|
||||
>
|
||||
<div id="scroll-gradient" className="scroll-gradient h-8">
|
||||
{" "}
|
||||
<span className=" inline-block h-6"></span>{" "}
|
||||
</div>
|
||||
{run.messages.map((msg, idx) => (
|
||||
<div
|
||||
key={"message_id" + idx + run.id}
|
||||
className=" mr-2"
|
||||
>
|
||||
<RenderMessage
|
||||
message={msg.config}
|
||||
isLast={idx === run.messages.length - 1}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
|
||||
{/* Input Request UI */}
|
||||
{run.status === "awaiting_input" && onInputResponse && (
|
||||
<div className="mt-4 mr-2">
|
||||
<InputRequestView
|
||||
prompt="Type your response..."
|
||||
onSubmit={onInputResponse}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<div className="text-primary mt-2">
|
||||
<div className="w-4 h-4 inline-block border-secondary rounded-bl-lg border-l-2 border-b-2"></div>{" "}
|
||||
<div className="inline-block ">
|
||||
{getStatusIcon(run.status)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Agent Flow Visualization */}
|
||||
<div className="bg-tertiary flex-1 rounded mt-2">
|
||||
{teamConfig && (
|
||||
<AgentFlow teamConfig={teamConfig} run={run} />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default RunView;
|
|
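For orientation, a minimal usage sketch of the RunView component above. The surrounding RunList wrapper, its props, and the handler wiring are illustrative assumptions and are not part of this diff; only RunView's own props come from the change.

import React from "react";
import { Run, TeamConfig } from "../../../types/datamodel";
import RunView from "./runview"; // path assumed for illustration

interface RunListProps {
  runs: Run[]; // runs loaded for the current session
  teamConfig?: TeamConfig; // team used to produce the runs
  onInputResponse: (response: string) => void; // forwarded to the active run
  onCancel: () => void; // stops the active run
}

// Render every run in a session; only an active or input-awaiting run receives the handlers.
const RunList: React.FC<RunListProps> = ({
  runs,
  teamConfig,
  onInputResponse,
  onCancel,
}) => (
  <div>
    {runs.map((run, idx) => {
      const isActive =
        run.status === "active" || run.status === "awaiting_input";
      return (
        <RunView
          key={run.id}
          run={run}
          teamConfig={teamConfig}
          isFirstRun={idx === 0}
          onInputResponse={isActive ? onInputResponse : undefined}
          onCancel={isActive ? onCancel : undefined}
        />
      );
    })}
  </div>
);

export default RunList;
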
@ -1,20 +1,18 @@
import React, { ReactNode } from "react";
import Markdown from "react-markdown";
import React from "react";

interface MarkdownViewProps {
  children: string;
  content: string;
  className?: string;
}

export const MarkdownView: React.FC<MarkdownViewProps> = ({
  children,
  content,
  className = "",
}) => {
  return (
    <div
      className={`text-sm w-full prose dark:prose-invert text-primary rounded p-2 ${className}`}
    >
      <Markdown>{children}</Markdown>
    <div className={`text-sm w-full text-primary rounded ${className}`}>
      <Markdown>{content}</Markdown>
    </div>
  );
};

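A short usage sketch for the updated MarkdownView, reflecting the switch from a children prop to an explicit content prop. The import path and the wrapper component are illustrative assumptions.

import React from "react";
import { MarkdownView } from "../shared/markdown"; // path assumed for illustration

// The markdown source is now passed via `content` rather than as children.
const MessageBody: React.FC<{ text: string }> = ({ text }) => (
  <MarkdownView content={text} className="mt-2" />
);

export default MessageBody;
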
@ -1,4 +1,4 @@
import { Session } from "../../../types/datamodel";
import { Session, SessionRuns } from "../../../types/datamodel";
import { getServerUrl } from "../../../utils";

export class SessionAPI {
@ -83,6 +83,23 @@ export class SessionAPI {
    return data.data;
  }

  // session runs with messages
  async getSessionRuns(
    sessionId: number,
    userId: string
  ): Promise<SessionRuns> {
    const response = await fetch(
      `${this.getBaseUrl()}/sessions/${sessionId}/runs?user_id=${userId}`,
      {
        headers: this.getHeaders(),
      }
    );
    const data = await response.json();
    if (!data.status)
      throw new Error(data.message || "Failed to fetch session runs");
    return data.data; // Returns { runs: RunMessage[] }
  }

  async deleteSession(sessionId: number, userId: string): Promise<void> {
    const response = await fetch(
      `${this.getBaseUrl()}/sessions/${sessionId}?user_id=${userId}`,

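A hedged sketch of how the new getSessionRuns endpoint might be consumed when a saved session is reopened. The helper name and the assumption that SessionRuns exposes a runs array (per the comment above) are illustrative, not part of this change.

import { SessionRuns } from "../../../types/datamodel";
import { SessionAPI } from "./api"; // path assumed for illustration

const sessionAPI = new SessionAPI();

// Load all runs (and their messages) for a session so the chat view can replay them.
export async function loadSessionHistory(
  sessionId: number,
  userId: string
): Promise<SessionRuns> {
  const sessionRuns = await sessionAPI.getSessionRuns(sessionId, userId);
  return sessionRuns; // e.g. sessionRuns.runs holds the ordered runs with messages
}
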
@ -8,7 +8,7 @@ const IndexPage = ({ data }: any) => {
  return (
    <Layout meta={data.site.siteMetadata} title="Home" link={"/"}>
      <main style={{ height: "100%" }} className=" h-full ">
        <ChatView initMessages={[]} />
        <ChatView />
      </main>
    </Layout>
  );

File diff suppressed because it is too large
|
@ -16,17 +16,9 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"task_result=TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the weather in New York?'), ToolCallMessage(source='writing_agent', models_usage=RequestUsage(prompt_tokens=65, completion_tokens=15), content=[FunctionCall(id='call_x8C5nib1PJkMZGQ6zrNUlfa0', arguments='{\"city\":\"New York\"}', name='get_weather')]), ToolCallResultMessage(source='writing_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 73 degrees and Sunny.', call_id='call_x8C5nib1PJkMZGQ6zrNUlfa0')]), TextMessage(source='writing_agent', models_usage=RequestUsage(prompt_tokens=97, completion_tokens=14), content='The weather in New York is currently 73 degrees and sunny.'), TextMessage(source='writing_agent', models_usage=RequestUsage(prompt_tokens=123, completion_tokens=13), content='Would you like to know the weather in any other city?')], stop_reason='Maximum number of messages 5 reached, current message count: 5') usage='' duration=1.9984567165374756\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from autogenstudio.teammanager import TeamManager \n",
|
||||
" \n",
|
||||
|
@ -37,22 +29,9 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"source='user' models_usage=None content='What is the weather in New York?'\n",
|
||||
"source='writing_agent' models_usage=RequestUsage(prompt_tokens=65, completion_tokens=15) content=[FunctionCall(id='call_Gwnfsa8ndnOsXTvRECTr92hr', arguments='{\"city\":\"New York\"}', name='get_weather')]\n",
|
||||
"source='writing_agent' models_usage=None content=[FunctionExecutionResult(content='The weather in New York is 73 degrees and Sunny.', call_id='call_Gwnfsa8ndnOsXTvRECTr92hr')]\n",
|
||||
"source='writing_agent' models_usage=RequestUsage(prompt_tokens=97, completion_tokens=14) content='The weather in New York is currently 73 degrees and sunny.'\n",
|
||||
"source='writing_agent' models_usage=RequestUsage(prompt_tokens=123, completion_tokens=14) content='The weather in New York is currently 73 degrees and sunny.'\n",
|
||||
"task_result=TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the weather in New York?'), ToolCallMessage(source='writing_agent', models_usage=RequestUsage(prompt_tokens=65, completion_tokens=15), content=[FunctionCall(id='call_Gwnfsa8ndnOsXTvRECTr92hr', arguments='{\"city\":\"New York\"}', name='get_weather')]), ToolCallResultMessage(source='writing_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 73 degrees and Sunny.', call_id='call_Gwnfsa8ndnOsXTvRECTr92hr')]), TextMessage(source='writing_agent', models_usage=RequestUsage(prompt_tokens=97, completion_tokens=14), content='The weather in New York is currently 73 degrees and sunny.'), TextMessage(source='writing_agent', models_usage=RequestUsage(prompt_tokens=123, completion_tokens=14), content='The weather in New York is currently 73 degrees and sunny.')], stop_reason='Maximum number of messages 5 reached, current message count: 5') usage='' duration=2.363379955291748\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"result_stream = wm.run_stream(task=\"What is the weather in New York?\", team_config=\"team.json\") \n",
|
||||
"async for response in result_stream:\n",
|
||||
|
@ -74,28 +53,32 @@
|
|||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO [alembic.runtime.migration] Context impl SQLiteImpl.\n",
|
||||
"INFO [alembic.runtime.migration] Will assume non-transactional DDL.\n",
|
||||
"\u001b[32m2024-11-14 09:06:25.242\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mautogenstudio.database.schema_manager\u001b[0m:\u001b[36mupgrade_schema\u001b[0m:\u001b[36m390\u001b[0m - \u001b[1mSchema upgraded successfully to head\u001b[0m\n",
|
||||
"\u001b[32m2024-11-14 09:06:25.243\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mautogenstudio.database.db_manager\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m34\u001b[0m - \u001b[1mDatabase schema was upgraded automatically\u001b[0m\n",
|
||||
"\u001b[32m2024-11-14 09:06:25.244\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mautogenstudio.database.db_manager\u001b[0m:\u001b[36mcreate_db_and_tables\u001b[0m:\u001b[36m108\u001b[0m - \u001b[1mDatabase tables created successfully\u001b[0m\n"
|
||||
]
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Response(message='Database is ready', status=True, data=None)"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from autogenstudio.database import DatabaseManager \n",
|
||||
"import os \n",
|
||||
"# delete database\n",
|
||||
"# if os.path.exists(\"test.db\"):\n",
|
||||
"# os.remove(\"test.db\") \n",
|
||||
"\n",
|
||||
"os.makedirs(\"test\", exist_ok=True)\n",
|
||||
"# create a database\n",
|
||||
"dbmanager = DatabaseManager(engine_uri=\"sqlite:///test.db\")\n",
|
||||
"dbmanager.create_db_and_tables() "
|
||||
"dbmanager = DatabaseManager(engine_uri=\"sqlite:///test.db\", base_dir=\"test\")\n",
|
||||
"dbmanager.initialize_database() "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -145,14 +128,14 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"53 teams in database\n"
|
||||
"2 teams in database\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -172,7 +155,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -194,17 +177,9 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"message='Directory import complete' status=True data=[{'component': 'team', 'status': True, 'message': 'Team Created Successfully', 'id': 54}]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"result = await config_manager.import_directory(\".\", user_id=\"user_id\", check_exists=False)\n",
|
||||
"print(result)"
|
||||
|
@ -212,17 +187,9 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"54 teams in database\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"all_teams = dbmanager.get(Team)\n",
|
||||
"print(len(all_teams.data), \"teams in database\")"
|
||||
|
@ -237,7 +204,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -282,22 +249,9 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"source='user' models_usage=None content='Plan a 3 day trip to Nepal.'\n",
|
||||
"source='planner_agent' models_usage=RequestUsage(prompt_tokens=45, completion_tokens=54) content=\"Consider starting your 3-day trip to Nepal with a cultural tour in Kathmandu Valley, followed by an exploration of Pokhara's natural beauty on the second day, and finally, indulge in a thrilling safari at Chitwan National Park on the third day.\"\n",
|
||||
"source='local_agent' models_usage=RequestUsage(prompt_tokens=116, completion_tokens=54) content=\"Consider starting your 3-day trip to Nepal with a cultural tour in Kathmandu Valley, followed by an exploration of Pokhara's natural beauty on the second day, and finally, indulge in a thrilling safari at Chitwan National Park on the third day.\"\n",
|
||||
"source='language_agent' models_usage=RequestUsage(prompt_tokens=201, completion_tokens=45) content=\"Your travel plan lacks a mention of dealing with potential language barriers; it might be useful to learn some phrases in Nepali, as it's the official language of Nepal, or have a local translation app handy during your trip.\"\n",
|
||||
"source='travel_summary_agent' models_usage=RequestUsage(prompt_tokens=270, completion_tokens=237) content='Day 1: Start your adventure in the capital city, Kathmandu. Take a guided tour of Kathmandu Valley to explore its UNESCO World Heritage Sites, such as the Durbar Squares, Swayambhunath Stupa, and Boudhanath Stupa. Engage with the locals and sample some traditional Nepalese cuisine.\\n\\nDay 2: Proceed to Pokhara, known for its stunning natural beauty. Visit the iconic Phewa Lake and enjoy a boat ride, then trek to the Peace Pagoda for a panoramic view of the city. Round off the day with a visit to the fascinating Pokhara Mountain Museum.\\n\\nDay 3: Travel to Chitwan National Park for a memorable safari. Explore the diverse wildlife and lush vegetation that make the park a UNESCO World Heritage site. Be on the lookout for rhinos, Bengal tigers, and a multitude of bird species.\\n\\nNote: Communication is key to enjoying your trip. The official language of Nepal is Nepali. It can be helpful to learn a few basic phrases or carry a translation app to help you interact with the local people and enrich your cultural experience.\\n\\nTERMINATE.'\n",
|
||||
"TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Plan a 3 day trip to Nepal.'), TextMessage(source='planner_agent', models_usage=RequestUsage(prompt_tokens=45, completion_tokens=54), content=\"Consider starting your 3-day trip to Nepal with a cultural tour in Kathmandu Valley, followed by an exploration of Pokhara's natural beauty on the second day, and finally, indulge in a thrilling safari at Chitwan National Park on the third day.\"), TextMessage(source='local_agent', models_usage=RequestUsage(prompt_tokens=116, completion_tokens=54), content=\"Consider starting your 3-day trip to Nepal with a cultural tour in Kathmandu Valley, followed by an exploration of Pokhara's natural beauty on the second day, and finally, indulge in a thrilling safari at Chitwan National Park on the third day.\"), TextMessage(source='language_agent', models_usage=RequestUsage(prompt_tokens=201, completion_tokens=45), content=\"Your travel plan lacks a mention of dealing with potential language barriers; it might be useful to learn some phrases in Nepali, as it's the official language of Nepal, or have a local translation app handy during your trip.\"), TextMessage(source='travel_summary_agent', models_usage=RequestUsage(prompt_tokens=270, completion_tokens=237), content='Day 1: Start your adventure in the capital city, Kathmandu. Take a guided tour of Kathmandu Valley to explore its UNESCO World Heritage Sites, such as the Durbar Squares, Swayambhunath Stupa, and Boudhanath Stupa. Engage with the locals and sample some traditional Nepalese cuisine.\\n\\nDay 2: Proceed to Pokhara, known for its stunning natural beauty. Visit the iconic Phewa Lake and enjoy a boat ride, then trek to the Peace Pagoda for a panoramic view of the city. Round off the day with a visit to the fascinating Pokhara Mountain Museum.\\n\\nDay 3: Travel to Chitwan National Park for a memorable safari. Explore the diverse wildlife and lush vegetation that make the park a UNESCO World Heritage site. Be on the lookout for rhinos, Bengal tigers, and a multitude of bird species.\\n\\nNote: Communication is key to enjoying your trip. The official language of Nepal is Nepali. It can be helpful to learn a few basic phrases or carry a translation app to help you interact with the local people and enrich your cultural experience.\\n\\nTERMINATE.')], stop_reason=\"Text 'TERMINATE' mentioned\")\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"result = group_chat.run_stream(task=\"Plan a 3 day trip to Nepal.\")\n",
|
||||
|
@ -317,7 +271,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -330,19 +284,9 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"source='user' models_usage=None content='hello there'\n",
|
||||
"source='user_agent' models_usage=None content='Hello World thereEnter your response: '\n",
|
||||
"TaskResult(messages=[TextMessage(source='user', models_usage=None, content='hello there'), TextMessage(source='user_agent', models_usage=None, content='Hello World thereEnter your response: ')], stop_reason=None)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from autogen_core.base import CancellationToken \n",
|
||||
"cancellation_token = CancellationToken()\n",
|
||||
|
|
|
@ -23,7 +23,8 @@ dependencies = [
    "pydantic-settings",
    "fastapi",
    "typer",
    "uvicorn",
    "uvicorn",
    "aiofiles",
    "python-dotenv",
    "websockets",
    "numpy < 2.0.0",

@ -69,3 +70,21 @@ filterwarnings = [

[project.scripts]
autogenstudio = "autogenstudio.cli:run"

[tool.ruff]
extend = "../../pyproject.toml"
exclude = ["build", "dist"]
include = [
    "autogenstudio/**"
]

[tool.ruff.lint]
ignore = ["B008"]

[tool.poe.tasks]
fmt = "ruff format"
format.ref = "fmt"
lint = "ruff check"
test = "pytest -n 0"

@ -6,10 +6,10 @@ from autogen_agentchat.teams import RoundRobinGroupChat, SelectorGroupChat
|
|||
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination, TextMentionTermination
|
||||
from autogen_core.components.tools import FunctionTool
|
||||
|
||||
from autogenstudio.datamodel import (
|
||||
from autogenstudio.datamodel.types import (
|
||||
AgentConfig, ModelConfig, TeamConfig, ToolConfig, TerminationConfig,
|
||||
ModelTypes, AgentTypes, TeamTypes, TerminationTypes, ToolTypes,
|
||||
ComponentType
|
||||
ComponentTypes
|
||||
)
|
||||
from autogenstudio.database import ComponentFactory
|
||||
|
||||
|
@ -41,7 +41,7 @@ def calculator(a: int, b: int, operation: str = '+') -> int:
|
|||
raise ValueError("Invalid operation")
|
||||
""",
|
||||
tool_type=ToolTypes.PYTHON_FUNCTION,
|
||||
component_type=ComponentType.TOOL,
|
||||
component_type=ComponentTypes.TOOL,
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
|
@ -52,7 +52,7 @@ def sample_model_config():
|
|||
model_type=ModelTypes.OPENAI,
|
||||
model="gpt-4",
|
||||
api_key="test-key",
|
||||
component_type=ComponentType.MODEL,
|
||||
component_type=ComponentTypes.MODEL,
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
|
@ -65,7 +65,7 @@ def sample_agent_config(sample_model_config: ModelConfig, sample_tool_config: To
|
|||
system_message="You are a helpful assistant",
|
||||
model_client=sample_model_config,
|
||||
tools=[sample_tool_config],
|
||||
component_type=ComponentType.AGENT,
|
||||
component_type=ComponentTypes.AGENT,
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
|
@ -75,7 +75,7 @@ def sample_termination_config():
|
|||
return TerminationConfig(
|
||||
termination_type=TerminationTypes.MAX_MESSAGES,
|
||||
max_messages=10,
|
||||
component_type=ComponentType.TERMINATION,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
|
@ -88,7 +88,7 @@ def sample_team_config(sample_agent_config: AgentConfig, sample_termination_conf
|
|||
participants=[sample_agent_config],
|
||||
termination_condition=sample_termination_config,
|
||||
model_client=sample_model_config,
|
||||
component_type=ComponentType.TEAM,
|
||||
component_type=ComponentTypes.TEAM,
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
|
@ -115,7 +115,7 @@ async def test_load_tool_invalid_config(component_factory: ComponentFactory):
|
|||
description="",
|
||||
content="",
|
||||
tool_type=ToolTypes.PYTHON_FUNCTION,
|
||||
component_type=ComponentType.TOOL,
|
||||
component_type=ComponentTypes.TOOL,
|
||||
version="1.0.0"
|
||||
))
|
||||
|
||||
|
@ -125,7 +125,7 @@ async def test_load_tool_invalid_config(component_factory: ComponentFactory):
|
|||
description="Invalid function",
|
||||
content="def invalid_func(): return invalid syntax",
|
||||
tool_type=ToolTypes.PYTHON_FUNCTION,
|
||||
component_type=ComponentType.TOOL,
|
||||
component_type=ComponentTypes.TOOL,
|
||||
version="1.0.0"
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
|
@ -154,7 +154,7 @@ async def test_load_termination(component_factory: ComponentFactory):
|
|||
max_msg_config = TerminationConfig(
|
||||
termination_type=TerminationTypes.MAX_MESSAGES,
|
||||
max_messages=5,
|
||||
component_type=ComponentType.TERMINATION,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
)
|
||||
termination = await component_factory.load_termination(max_msg_config)
|
||||
|
@ -164,7 +164,7 @@ async def test_load_termination(component_factory: ComponentFactory):
|
|||
# Test StopMessageTermination
|
||||
stop_msg_config = TerminationConfig(
|
||||
termination_type=TerminationTypes.STOP_MESSAGE,
|
||||
component_type=ComponentType.TERMINATION,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
)
|
||||
termination = await component_factory.load_termination(stop_msg_config)
|
||||
|
@ -174,13 +174,108 @@ async def test_load_termination(component_factory: ComponentFactory):
|
|||
text_mention_config = TerminationConfig(
|
||||
termination_type=TerminationTypes.TEXT_MENTION,
|
||||
text="DONE",
|
||||
component_type=ComponentType.TERMINATION,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
)
|
||||
termination = await component_factory.load_termination(text_mention_config)
|
||||
assert isinstance(termination, TextMentionTermination)
|
||||
assert termination._text == "DONE"
|
||||
|
||||
# Test AND combination
|
||||
and_combo_config = TerminationConfig(
|
||||
termination_type=TerminationTypes.COMBINATION,
|
||||
operator="and",
|
||||
conditions=[
|
||||
TerminationConfig(
|
||||
termination_type=TerminationTypes.MAX_MESSAGES,
|
||||
max_messages=5,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
),
|
||||
TerminationConfig(
|
||||
termination_type=TerminationTypes.TEXT_MENTION,
|
||||
text="DONE",
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
)
|
||||
],
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
)
|
||||
termination = await component_factory.load_termination(and_combo_config)
|
||||
assert termination is not None
|
||||
|
||||
# Test OR combination
|
||||
or_combo_config = TerminationConfig(
|
||||
termination_type=TerminationTypes.COMBINATION,
|
||||
operator="or",
|
||||
conditions=[
|
||||
TerminationConfig(
|
||||
termination_type=TerminationTypes.MAX_MESSAGES,
|
||||
max_messages=5,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
),
|
||||
TerminationConfig(
|
||||
termination_type=TerminationTypes.TEXT_MENTION,
|
||||
text="DONE",
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
)
|
||||
],
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
)
|
||||
termination = await component_factory.load_termination(or_combo_config)
|
||||
assert termination is not None
|
||||
|
||||
# Test invalid combinations
|
||||
with pytest.raises(ValueError):
|
||||
await component_factory.load_termination(TerminationConfig(
|
||||
termination_type=TerminationTypes.COMBINATION,
|
||||
conditions=[], # Empty conditions
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await component_factory.load_termination(TerminationConfig(
|
||||
termination_type=TerminationTypes.COMBINATION,
|
||||
operator="invalid", # type: ignore
|
||||
conditions=[
|
||||
TerminationConfig(
|
||||
termination_type=TerminationTypes.MAX_MESSAGES,
|
||||
max_messages=5,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
)
|
||||
],
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
))
|
||||
|
||||
# Test missing operator
|
||||
with pytest.raises(ValueError):
|
||||
await component_factory.load_termination(TerminationConfig(
|
||||
termination_type=TerminationTypes.COMBINATION,
|
||||
conditions=[
|
||||
TerminationConfig(
|
||||
termination_type=TerminationTypes.MAX_MESSAGES,
|
||||
max_messages=5,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
),
|
||||
TerminationConfig(
|
||||
termination_type=TerminationTypes.TEXT_MENTION,
|
||||
text="DONE",
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
)
|
||||
],
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_team(component_factory: ComponentFactory, sample_team_config: TeamConfig, sample_model_config: ModelConfig):
|
||||
|
@ -201,13 +296,13 @@ async def test_load_team(component_factory: ComponentFactory, sample_team_config
|
|||
system_message="You are another helpful assistant",
|
||||
model_client=sample_model_config,
|
||||
tools=sample_team_config.participants[0].tools,
|
||||
component_type=ComponentType.AGENT,
|
||||
component_type=ComponentTypes.AGENT,
|
||||
version="1.0.0"
|
||||
)
|
||||
],
|
||||
termination_condition=sample_team_config.termination_condition,
|
||||
model_client=sample_model_config,
|
||||
component_type=ComponentType.TEAM,
|
||||
component_type=ComponentTypes.TEAM,
|
||||
version="1.0.0"
|
||||
)
|
||||
team = await component_factory.load_team(selector_team_config)
|
||||
|
@ -223,7 +318,7 @@ async def test_invalid_configs(component_factory: ComponentFactory):
|
|||
name="test",
|
||||
agent_type="InvalidAgent", # type: ignore
|
||||
system_message="test",
|
||||
component_type=ComponentType.AGENT,
|
||||
component_type=ComponentTypes.AGENT,
|
||||
version="1.0.0"
|
||||
))
|
||||
|
||||
|
@ -233,7 +328,7 @@ async def test_invalid_configs(component_factory: ComponentFactory):
|
|||
name="test",
|
||||
team_type="InvalidTeam", # type: ignore
|
||||
participants=[],
|
||||
component_type=ComponentType.TEAM,
|
||||
component_type=ComponentTypes.TEAM,
|
||||
version="1.0.0"
|
||||
))
|
||||
|
||||
|
@ -241,6 +336,6 @@ async def test_invalid_configs(component_factory: ComponentFactory):
|
|||
with pytest.raises(ValueError):
|
||||
await component_factory.load_termination(TerminationConfig(
|
||||
termination_type="InvalidTermination", # type: ignore
|
||||
component_type=ComponentType.TERMINATION,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
))
|
||||
|
|
|
@ -1,15 +1,16 @@
|
|||
import os
|
||||
import asyncio
|
||||
import pytest
|
||||
from sqlmodel import Session, text, select
|
||||
from typing import Generator
|
||||
from datetime import datetime
|
||||
|
||||
from autogenstudio.database import DatabaseManager
|
||||
from autogenstudio.datamodel import (
|
||||
Model, ModelConfig, Agent, AgentConfig, Tool, ToolConfig,
|
||||
Team, TeamConfig, ModelTypes, AgentTypes, TeamTypes, ComponentType,
|
||||
TerminationConfig, TerminationTypes, LinkTypes, ToolTypes
|
||||
from autogenstudio.datamodel.types import (
|
||||
ModelConfig, AgentConfig, ToolConfig,
|
||||
TeamConfig, ModelTypes, AgentTypes, TeamTypes, ComponentTypes,
|
||||
TerminationConfig, TerminationTypes, ToolTypes
|
||||
)
|
||||
from autogenstudio.datamodel.db import Model, Tool, Agent, Team, LinkTypes
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -18,13 +19,13 @@ def test_db() -> Generator[DatabaseManager, None, None]:
|
|||
db_path = "test.db"
|
||||
db = DatabaseManager(f"sqlite:///{db_path}")
|
||||
db.reset_db()
|
||||
db.create_db_and_tables()
|
||||
# Initialize database instead of create_db_and_tables
|
||||
db.initialize_database(auto_upgrade=False)
|
||||
yield db
|
||||
# Clean up
|
||||
asyncio.run(db.close())
|
||||
db.reset_db()
|
||||
try:
|
||||
# Close database connections before removing file
|
||||
db.engine.dispose()
|
||||
# Remove the database file
|
||||
if os.path.exists(db_path):
|
||||
os.remove(db_path)
|
||||
except Exception as e:
|
||||
|
@ -44,7 +45,7 @@ def sample_model(test_user: str) -> Model:
|
|||
config=ModelConfig(
|
||||
model="gpt-4",
|
||||
model_type=ModelTypes.OPENAI,
|
||||
component_type=ComponentType.MODEL,
|
||||
component_type=ComponentTypes.MODEL,
|
||||
version="1.0.0"
|
||||
).model_dump()
|
||||
)
|
||||
|
@ -60,7 +61,7 @@ def sample_tool(test_user: str) -> Tool:
|
|||
description="A test tool",
|
||||
content="async def test_func(x: str) -> str:\n return f'Test {x}'",
|
||||
tool_type=ToolTypes.PYTHON_FUNCTION,
|
||||
component_type=ComponentType.TOOL,
|
||||
component_type=ComponentTypes.TOOL,
|
||||
version="1.0.0"
|
||||
).model_dump()
|
||||
)
|
||||
|
@ -76,7 +77,7 @@ def sample_agent(test_user: str, sample_model: Model, sample_tool: Tool) -> Agen
|
|||
agent_type=AgentTypes.ASSISTANT,
|
||||
model_client=ModelConfig.model_validate(sample_model.config),
|
||||
tools=[ToolConfig.model_validate(sample_tool.config)],
|
||||
component_type=ComponentType.AGENT,
|
||||
component_type=ComponentTypes.AGENT,
|
||||
version="1.0.0"
|
||||
).model_dump()
|
||||
)
|
||||
|
@ -92,11 +93,11 @@ def sample_team(test_user: str, sample_agent: Agent) -> Team:
|
|||
participants=[AgentConfig.model_validate(sample_agent.config)],
|
||||
termination_condition=TerminationConfig(
|
||||
termination_type=TerminationTypes.STOP_MESSAGE,
|
||||
component_type=ComponentType.TERMINATION,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
).model_dump(),
|
||||
team_type=TeamTypes.ROUND_ROBIN,
|
||||
component_type=ComponentType.TEAM,
|
||||
component_type=ComponentTypes.TEAM,
|
||||
version="1.0.0"
|
||||
).model_dump()
|
||||
)
|
||||
|
@ -144,7 +145,7 @@ class TestDatabaseOperations:
|
|||
config=ModelConfig(
|
||||
model="gpt-4",
|
||||
model_type=ModelTypes.OPENAI,
|
||||
component_type=ComponentType.MODEL,
|
||||
component_type=ComponentTypes.MODEL,
|
||||
version="1.0.0"
|
||||
).model_dump()
|
||||
)
|
||||
|
@ -153,7 +154,7 @@ class TestDatabaseOperations:
|
|||
config=ModelConfig(
|
||||
model="gpt-3.5",
|
||||
model_type=ModelTypes.OPENAI,
|
||||
component_type=ComponentType.MODEL,
|
||||
component_type=ComponentTypes.MODEL,
|
||||
version="1.0.0"
|
||||
).model_dump()
|
||||
)
|
||||
|
@ -181,3 +182,59 @@ class TestDatabaseOperations:
|
|||
model_names = [model.config["model"] for model in linked_models.data]
|
||||
assert "gpt-4" in model_names
|
||||
assert "gpt-3.5" in model_names
|
||||
|
||||
def test_upsert_operations(self, test_db: DatabaseManager, sample_model: Model):
|
||||
"""Test upsert for both create and update scenarios"""
|
||||
# Test Create
|
||||
response = test_db.upsert(sample_model)
|
||||
assert response.status is True
|
||||
assert "Created Successfully" in response.message
|
||||
|
||||
# Test Update
|
||||
sample_model.config["model"] = "gpt-4-turbo"
|
||||
response = test_db.upsert(sample_model)
|
||||
assert response.status is True
|
||||
assert "Updated Successfully" in response.message
|
||||
|
||||
# Verify Update
|
||||
result = test_db.get(Model, {"id": sample_model.id})
|
||||
assert result.status is True
|
||||
assert result.data[0].config["model"] == "gpt-4-turbo"
|
||||
|
||||
def test_delete_operations(self, test_db: DatabaseManager, sample_model: Model):
|
||||
"""Test delete with various filters"""
|
||||
# First insert the model
|
||||
test_db.upsert(sample_model)
|
||||
|
||||
# Test deletion by id
|
||||
response = test_db.delete(Model, {"id": sample_model.id})
|
||||
assert response.status is True
|
||||
assert "Deleted Successfully" in response.message
|
||||
|
||||
# Verify deletion
|
||||
result = test_db.get(Model, {"id": sample_model.id})
|
||||
assert len(result.data) == 0
|
||||
|
||||
# Test deletion with non-existent id
|
||||
response = test_db.delete(Model, {"id": 999999})
|
||||
assert "Row not found" in response.message
|
||||
|
||||
def test_initialize_database_scenarios(self):
|
||||
"""Test different initialize_database parameters"""
|
||||
db_path = "test_init.db"
|
||||
db = DatabaseManager(f"sqlite:///{db_path}")
|
||||
|
||||
try:
|
||||
# Test basic initialization
|
||||
response = db.initialize_database()
|
||||
assert response.status is True
|
||||
|
||||
# Test with auto_upgrade
|
||||
response = db.initialize_database(auto_upgrade=True)
|
||||
assert response.status is True
|
||||
|
||||
finally:
|
||||
asyncio.run(db.close())
|
||||
db.reset_db()
|
||||
if os.path.exists(db_path):
|
||||
os.remove(db_path)
|
||||
|
|
|
@ -573,9 +573,10 @@ dev = [
|
|||
|
||||
[[package]]
|
||||
name = "autogenstudio"
|
||||
version = "0.4.0.dev37"
|
||||
version = "0.4.0.dev38"
|
||||
source = { editable = "packages/autogen-studio" }
|
||||
dependencies = [
|
||||
{ name = "aiofiles" },
|
||||
{ name = "alembic" },
|
||||
{ name = "autogen-agentchat" },
|
||||
{ name = "autogen-core" },
|
||||
|
@ -605,6 +606,7 @@ web = [
|
|||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "aiofiles" },
|
||||
{ name = "alembic" },
|
||||
{ name = "autogen-agentchat", editable = "packages/autogen-agentchat" },
|
||||
{ name = "autogen-core", editable = "packages/autogen-core" },
|
||||
|
|