autogen/python/packages/autogen-studio/autogenstudio/web/managers/connection.py

463 lines
18 KiB
Python

import asyncio
import logging
import traceback
from datetime import datetime, timezone
from typing import Any, Callable, Dict, Optional, Union
from autogen_agentchat.base._task import TaskResult
from autogen_agentchat.messages import (
BaseAgentEvent,
BaseChatMessage,
HandoffMessage,
ModelClientStreamingChunkEvent,
MultiModalMessage,
StopMessage,
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
)
from autogen_core import CancellationToken
from autogen_core import Image as AGImage
from fastapi import WebSocket, WebSocketDisconnect
from ...database import DatabaseManager
from ...datamodel import (
LLMCallEventMessage,
Message,
MessageConfig,
Run,
RunStatus,
Settings,
SettingsConfig,
TeamResult,
)
from ...teammanager import TeamManager
# Module-level logger named after this module; used by WebSocketManager below.
logger = logging.getLogger(__name__)
class WebSocketManager:
    """Manages WebSocket connections and message streaming for team task execution.

    All per-run state is keyed by the run id:

    * ``_connections`` maps a run id to its accepted WebSocket.
    * ``_cancellation_tokens`` holds the token of each run whose stream is in progress.
    * ``_closed_connections`` lists runs whose connection was explicitly closed; no further
      messages are sent to them.
    * ``_input_responses`` holds, per run, the queue of client replies consumed by the
      input function handed to the team.
    """

    def __init__(self, db_manager: DatabaseManager):
        self.db_manager = db_manager
        self._connections: Dict[int, WebSocket] = {}
        self._cancellation_tokens: Dict[int, CancellationToken] = {}
        # Track explicitly closed connections
        self._closed_connections: set[int] = set()
        self._input_responses: Dict[int, asyncio.Queue] = {}
        # Pre-serialized TeamResult reported when a stream ends because it was cancelled.
        self._cancel_message = TeamResult(
            task_result=TaskResult(
                messages=[TextMessage(source="user", content="Run cancelled by user")], stop_reason="cancelled by user"
            ),
            usage="",
            duration=0,
        ).model_dump()

    def _get_stop_message(self, reason: str) -> dict:
        """Build a serialized TeamResult whose single message and stop reason are ``reason``."""
        return TeamResult(
            task_result=TaskResult(messages=[TextMessage(source="user", content=reason)], stop_reason=reason),
            usage="",
            duration=0,
        ).model_dump()

    async def connect(self, websocket: WebSocket, run_id: int) -> bool:
        """Accept ``websocket`` and register it as the connection for ``run_id``.

        Clears any previous "closed" mark for the run, creates a fresh input queue and
        sends a ``system``/``connected`` message to the client.

        Returns:
            True if the connection was set up, False if an exception occurred (it is logged).
        """
        try:
            await websocket.accept()
            self._connections[run_id] = websocket
            self._closed_connections.discard(run_id)
            # Initialize input queue for this connection
            self._input_responses[run_id] = asyncio.Queue()

            await self._send_message(
                run_id, {"type": "system", "status": "connected", "timestamp": datetime.now(timezone.utc).isoformat()}
            )

            return True
        except Exception as e:
            logger.error(f"Connection error for run {run_id}: {e}")
            return False

    async def start_stream(self, run_id: int, task: str, team_config: dict) -> None:
        """Run a team on ``task`` and stream each produced message to the run's WebSocket.

        The steps are:

        1. Before streaming, the run record (if found) gets the task and is marked ACTIVE.
        2. Messages of persisted types are saved to the database as they arrive.
        3. On normal completion, the last ``TeamResult`` seen is stored and the run is
           marked COMPLETE.
        4. If the run was cancelled or its connection closed, a cancellation completion
           message is sent and the run is marked STOPPED.

        Any exception raised along the way is logged and passed to ``_handle_stream_error``.

        Args:
            run_id: id of the run being executed.
            task: task text handed to the team.
            team_config: team configuration forwarded to ``TeamManager.run_stream``.

        Raises:
            ValueError: if there is no open connection for ``run_id``.
        """
        if run_id not in self._connections or run_id in self._closed_connections:
            raise ValueError(f"No active connection for run {run_id}")

        team_manager = TeamManager()
        cancellation_token = CancellationToken()
        self._cancellation_tokens[run_id] = cancellation_token
        final_result = None

        try:
            # Update run with task and status
            run = await self._get_run(run_id)

            # Fetch the user's settings (environment variables) only when the run exists.
            # Reading ``run.user_id`` on a missing run would raise AttributeError.
            user_settings = await self._get_settings(run.user_id) if run else None
            env_vars = SettingsConfig(**user_settings.config).environment if user_settings else None

            if run:
                run.task = MessageConfig(content=task, source="user").model_dump()
                run.status = RunStatus.ACTIVE
                self.db_manager.upsert(run)

            input_func = self.create_input_func(run_id)

            async for message in team_manager.run_stream(
                task=task,
                team_config=team_config,
                input_func=input_func,
                cancellation_token=cancellation_token,
                env_vars=env_vars,
            ):
                if cancellation_token.is_cancelled() or run_id in self._closed_connections:
                    logger.info(f"Stream cancelled or connection closed for run {run_id}")
                    break

                formatted_message = self._format_message(message)
                if formatted_message:
                    await self._send_message(run_id, formatted_message)

                    # Save messages by concrete type
                    if isinstance(
                        message,
                        (
                            TextMessage,
                            MultiModalMessage,
                            StopMessage,
                            HandoffMessage,
                            ToolCallRequestEvent,
                            ToolCallExecutionEvent,
                            LLMCallEventMessage,
                        ),
                    ):
                        await self._save_message(run_id, message)
                    # Capture final result if it's a TeamResult
                    elif isinstance(message, TeamResult):
                        final_result = message.model_dump()

            if not cancellation_token.is_cancelled() and run_id not in self._closed_connections:
                if final_result:
                    await self._update_run(run_id, RunStatus.COMPLETE, team_result=final_result)
                else:
                    logger.warning(f"No final result captured for completed run {run_id}")
                    await self._update_run_status(run_id, RunStatus.COMPLETE)
            else:
                await self._send_message(
                    run_id,
                    {
                        "type": "completion",
                        "status": "cancelled",
                        "data": self._cancel_message,
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
                # Update run with cancellation result
                await self._update_run(run_id, RunStatus.STOPPED, team_result=self._cancel_message)

        except Exception as e:
            logger.error(f"Stream error for run {run_id}: {e}")
            traceback.print_exc()
            await self._handle_stream_error(run_id, e)

        finally:
            self._cancellation_tokens.pop(run_id, None)

    async def _save_message(self, run_id: int, message: Union[BaseAgentEvent, BaseChatMessage]) -> None:
        """Persist ``message`` as a Message row attached to run ``run_id``.

        Does nothing if the run record cannot be found.
        """
        run = await self._get_run(run_id)
        if run:
            db_message = Message(
                session_id=run.session_id,
                run_id=run_id,
                config=message.model_dump(),
                # Attribute the message to the owner of the run it belongs to.
                user_id=run.user_id,
            )
            self.db_manager.upsert(db_message)

    async def _update_run(
        self, run_id: int, status: RunStatus, team_result: Optional[dict] = None, error: Optional[str] = None
    ) -> None:
        """Set the run's status and, when given (truthy), its team result and error message.

        Does nothing if the run record cannot be found.
        """
        run = await self._get_run(run_id)
        if run:
            run.status = status
            if team_result:
                run.team_result = team_result
            if error:
                run.error_message = error
            self.db_manager.upsert(run)

    def create_input_func(self, run_id: int) -> Callable:
        """Create the async input function the team uses to ask the client for input.

        The returned coroutine function sends an ``input_request`` message with the prompt,
        then waits for the next value pushed by ``handle_input_response``. If a
        cancellation token is supplied, the wait is cancelled when that token is
        cancelled, so a stopped run does not block forever waiting for input.
        """

        async def input_handler(prompt: str = "", cancellation_token: Optional[CancellationToken] = None) -> str:
            try:
                # Send input request to client
                await self._send_message(
                    run_id,
                    {
                        "type": "input_request",
                        "prompt": prompt,
                        "data": {"source": "system", "content": prompt},
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )

                # Wait for response. Look the queue up after sending, because a failed send
                # may have disconnected the run and removed it.
                queue = self._input_responses.get(run_id)
                if queue is None:
                    raise ValueError(f"No input queue for run {run_id}")

                pending_response = asyncio.ensure_future(queue.get())
                if cancellation_token is not None:
                    # Cancel the wait when the run is cancelled instead of hanging indefinitely.
                    cancellation_token.link_future(pending_response)
                return await pending_response

            except Exception as e:
                logger.error(f"Error handling input for run {run_id}: {e}")
                raise

        return input_handler

    async def handle_input_response(self, run_id: int, response: str) -> None:
        """Deliver a client's input ``response`` to the waiting input function of ``run_id``."""
        if run_id in self._input_responses:
            await self._input_responses[run_id].put(response)
        else:
            logger.warning(f"Received input response for inactive run {run_id}")

    async def stop_run(self, run_id: int, reason: str) -> None:
        """Stop a running stream: record it as STOPPED, notify the client, and cancel it.

        Does nothing if ``run_id`` has no active cancellation token. The token is always
        cancelled, even when the database update or the client notification fails.
        """
        # Capture the token up front: _send_message may call disconnect(), which removes
        # it from _cancellation_tokens before we get to cancel it.
        cancellation_token = self._cancellation_tokens.get(run_id)
        if cancellation_token is None:
            return

        logger.info(f"Stopping run {run_id}")
        stop_message = self._get_stop_message(reason)

        try:
            # Update run record first
            await self._update_run(run_id, status=RunStatus.STOPPED, team_result=stop_message)

            # Then handle websocket communication if connection is active
            if run_id in self._connections and run_id not in self._closed_connections:
                await self._send_message(
                    run_id,
                    {
                        "type": "completion",
                        "status": "cancelled",
                        "data": stop_message,
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
        except Exception as e:
            logger.error(f"Error stopping run {run_id}: {e}")
        finally:
            # Finally cancel the token (idempotent if a nested disconnect already did so).
            cancellation_token.cancel()

    async def disconnect(self, run_id: int) -> None:
        """Clean up connection and associated resources for ``run_id``.

        The run is marked closed first so no further messages are sent to it. Any running
        stream is then stopped, and the connection, token and input queue are dropped.
        """
        logger.info(f"Disconnecting run {run_id}")

        # Mark as closed before cleanup to prevent any new messages
        self._closed_connections.add(run_id)

        # Cancel any running tasks
        await self.stop_run(run_id, "Connection closed")

        # Clean up resources
        self._connections.pop(run_id, None)
        self._cancellation_tokens.pop(run_id, None)
        self._input_responses.pop(run_id, None)

    async def _send_message(self, run_id: int, message: dict) -> None:
        """Send a message through the WebSocket with connection state checking.

        If the run's connection is marked closed, a warning is logged and nothing is sent.
        A ``WebSocketDisconnect`` during the send disconnects the run. Any other error
        marks the run as ERROR and then disconnects it.

        Args:
            run_id: id of the run
            message: Message dictionary to send
        """
        if run_id in self._closed_connections:
            logger.warning(f"Attempted to send message to closed connection for run {run_id}")
            return

        try:
            if run_id in self._connections:
                websocket = self._connections[run_id]
                await websocket.send_json(message)
        except WebSocketDisconnect:
            logger.warning(f"WebSocket disconnected while sending message for run {run_id}")
            await self.disconnect(run_id)
        except Exception as e:
            logger.error(f"Error sending message for run {run_id}: {e}, {message}")
            # Don't try to send error message here to avoid potential recursive loop
            await self._update_run_status(run_id, RunStatus.ERROR, str(e))
            await self.disconnect(run_id)

    async def _handle_stream_error(self, run_id: int, error: Exception) -> None:
        """Report a stream failure to the client (if still connected) and mark the run ERROR."""
        if run_id not in self._closed_connections:
            error_result = TeamResult(
                task_result=TaskResult(
                    messages=[TextMessage(source="system", content=str(error))],
                    stop_reason="An error occurred while processing this run",
                ),
                usage="",
                duration=0,
            ).model_dump()

            await self._send_message(
                run_id,
                {
                    "type": "completion",
                    "status": "error",
                    "data": error_result,
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                },
            )

            await self._update_run(run_id, RunStatus.ERROR, team_result=error_result, error=str(error))

    def _format_message(self, message: Any) -> Optional[dict]:
        """Format message for WebSocket transmission.

        Image items of a ``MultiModalMessage`` are converted in place into
        ``{"url": <PNG data URL>, "alt": ...}`` entries. Other items are kept as dumped,
        whatever their position or count.

        Args:
            message: Message to format

        Returns:
            Optional[dict]: Formatted message, or None if the type is not streamed or
            formatting fails.
        """
        try:
            if isinstance(message, MultiModalMessage):
                message_dump = message.model_dump()
                # Pair each original item with its dumped form so images are detected by
                # type rather than assumed to sit at a fixed index.
                message_dump["content"] = [
                    {
                        "url": f"data:image/png;base64,{dumped['data']}",
                        "alt": "WebSurfer Screenshot",
                    }
                    if isinstance(original, AGImage)
                    else dumped
                    for original, dumped in zip(message.content, message_dump["content"])
                ]
                return {"type": "message", "data": message_dump}

            elif isinstance(message, TeamResult):
                return {
                    "type": "result",
                    "data": message.model_dump(),
                    "status": "complete",
                }
            elif isinstance(message, ModelClientStreamingChunkEvent):
                return {"type": "message_chunk", "data": message.model_dump()}

            elif isinstance(
                message,
                (
                    TextMessage,
                    StopMessage,
                    HandoffMessage,
                    ToolCallRequestEvent,
                    ToolCallExecutionEvent,
                    LLMCallEventMessage,
                ),
            ):
                return {"type": "message", "data": message.model_dump()}

            return None

        except Exception as e:
            logger.error(f"Message formatting error: {e}")
            return None

    async def _get_run(self, run_id: int) -> Optional[Run]:
        """Get run from database.

        Args:
            run_id: id of the run to retrieve

        Returns:
            Optional[Run]: Run object if found, None otherwise
        """
        response = self.db_manager.get(Run, filters={"id": run_id}, return_json=False)
        return response.data[0] if response.status and response.data else None

    async def _get_settings(self, user_id: str) -> Optional[Settings]:
        """Get user settings from database.

        Args:
            user_id: User ID to retrieve settings for

        Returns:
            Optional[Settings]: First matching Settings record, or None if none was found
        """
        response = self.db_manager.get(filters={"user_id": user_id}, model_class=Settings, return_json=False)
        return response.data[0] if response.status and response.data else None

    async def _update_run_status(self, run_id: int, status: RunStatus, error: Optional[str] = None) -> None:
        """Update run status in database.

        Unlike ``_update_run``, this always overwrites ``error_message`` (clearing it when
        ``error`` is None).

        Args:
            run_id: id of the run to update
            status: New status to set
            error: Optional error message
        """
        run = await self._get_run(run_id)
        if run:
            run.status = status
            run.error_message = error
            self.db_manager.upsert(run)

    async def cleanup(self) -> None:
        """Clean up all active connections and resources when server is shutting down.

        Cleanup runs in three phases:

        1. Every in-progress run is cancelled, and ACTIVE runs are recorded as STOPPED with
           a "server shutdown" result. A failure on one run is logged and does not prevent
           the others from being processed.
        2. Active connections are disconnected, allowing 2 s per connection and 10 s in
           total.
        3. Internal state is cleared, even if an earlier phase failed.
        """
        logger.info(f"Cleaning up {len(self.active_connections)} active connections")

        try:
            # First cancel all running tasks (active_runs already returns a fresh set)
            for run_id in self.active_runs:
                try:
                    if run_id in self._cancellation_tokens:
                        self._cancellation_tokens[run_id].cancel()
                    run = await self._get_run(run_id)
                    if run and run.status == RunStatus.ACTIVE:
                        interrupted_result = TeamResult(
                            task_result=TaskResult(
                                messages=[TextMessage(source="system", content="Run interrupted by server shutdown")],
                                stop_reason="server_shutdown",
                            ),
                            usage="",
                            duration=0,
                        ).model_dump()

                        run.status = RunStatus.STOPPED
                        run.team_result = interrupted_result
                        self.db_manager.upsert(run)
                except Exception as e:
                    logger.error(f"Error interrupting run {run_id}: {e}")

            # Then disconnect all websockets with timeout: 10 seconds for the whole phase.
            # asyncio.wait_for is used instead of asyncio.timeout, which needs Python 3.11+.
            await asyncio.wait_for(self._disconnect_all(per_connection_timeout=2), timeout=10)

        except asyncio.TimeoutError:
            logger.warning("WebSocketManager cleanup timed out")
        except Exception as e:
            logger.error(f"Error during WebSocketManager cleanup: {e}")
        finally:
            # Always clear internal state, even if cleanup had errors
            self._connections.clear()
            self._cancellation_tokens.clear()
            self._closed_connections.clear()
            self._input_responses.clear()

    async def _disconnect_all(self, per_connection_timeout: float) -> None:
        """Disconnect every active connection, allowing each at most ``per_connection_timeout`` seconds."""
        for run_id in self.active_connections:
            try:
                await asyncio.wait_for(self.disconnect(run_id), timeout=per_connection_timeout)
            except asyncio.TimeoutError:
                logger.warning(f"Timeout disconnecting run {run_id}")
            except Exception as e:
                logger.error(f"Error disconnecting run {run_id}: {e}")

    @property
    def active_connections(self) -> set[int]:
        """Get set of active run IDs (registered and not marked closed)."""
        return set(self._connections.keys()) - self._closed_connections

    @property
    def active_runs(self) -> set[int]:
        """Get set of runs with active cancellation tokens."""
        return set(self._cancellation_tokens.keys())