mirror of https://github.com/microsoft/autogen.git
248 lines
8.3 KiB
Python
248 lines
8.3 KiB
Python
# managers/connection.py
|
|
from fastapi import WebSocket, WebSocketDisconnect
|
|
from typing import Dict, Optional, Any
|
|
from uuid import UUID
|
|
import logging
|
|
from datetime import datetime, timezone
|
|
|
|
from ...datamodel import Run, RunStatus, TeamResult
|
|
from ...database import DatabaseManager
|
|
from autogen_agentchat.messages import InnerMessage, ChatMessage
|
|
from autogen_core.base import CancellationToken
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class WebSocketManager:
|
|
"""Manages WebSocket connections and message streaming for team task execution"""
|
|
|
|
def __init__(self, db_manager: DatabaseManager):
|
|
self.db_manager = db_manager
|
|
self._connections: Dict[UUID, WebSocket] = {}
|
|
self._cancellation_tokens: Dict[UUID, CancellationToken] = {}
|
|
|
|
async def connect(self, websocket: WebSocket, run_id: UUID) -> bool:
|
|
"""Initialize WebSocket connection for a run
|
|
|
|
Args:
|
|
websocket: The WebSocket connection to initialize
|
|
run_id: UUID of the run to associate with this connection
|
|
|
|
Returns:
|
|
bool: True if connection was successful, False otherwise
|
|
"""
|
|
try:
|
|
await websocket.accept()
|
|
self._connections[run_id] = websocket
|
|
|
|
run = await self._get_run(run_id)
|
|
if run:
|
|
run.status = RunStatus.ACTIVE
|
|
self.db_manager.upsert(run)
|
|
|
|
await self._send_message(run_id, {
|
|
"type": "system",
|
|
"status": "connected",
|
|
"timestamp": datetime.now(timezone.utc).isoformat()
|
|
})
|
|
|
|
return True
|
|
|
|
except Exception as e:
|
|
logger.error(f"Connection error for run {run_id}: {e}")
|
|
return False
|
|
|
|
async def start_stream(
|
|
self,
|
|
run_id: UUID,
|
|
team_manager: Any,
|
|
task: str,
|
|
team_config: dict
|
|
) -> None:
|
|
"""Start streaming task execution
|
|
|
|
Args:
|
|
run_id: UUID of the run
|
|
team_manager: Instance of the team manager
|
|
task: Task string to execute
|
|
team_config: Team configuration dictionary
|
|
"""
|
|
if run_id not in self._connections:
|
|
raise ValueError(f"No active connection for run {run_id}")
|
|
|
|
cancellation_token = CancellationToken()
|
|
self._cancellation_tokens[run_id] = cancellation_token
|
|
|
|
try:
|
|
async for message in team_manager.run_stream(
|
|
task=task,
|
|
team_config=team_config,
|
|
cancellation_token=cancellation_token
|
|
):
|
|
|
|
if cancellation_token.is_cancelled():
|
|
logger.info(f"Stream cancelled for run {run_id}")
|
|
break
|
|
|
|
formatted_message = self._format_message(message)
|
|
if formatted_message:
|
|
await self._send_message(run_id, formatted_message)
|
|
|
|
# Only send completion if not cancelled
|
|
if not cancellation_token.is_cancelled():
|
|
# await self._send_message(run_id, {
|
|
# "type": "completion",
|
|
# "status": "complete",
|
|
# "timestamp": datetime.now(timezone.utc).isoformat()
|
|
# })
|
|
await self._update_run_status(run_id, RunStatus.COMPLETE)
|
|
else:
|
|
await self._send_message(run_id, {
|
|
"type": "completion",
|
|
"status": "cancelled",
|
|
"timestamp": datetime.now(timezone.utc).isoformat()
|
|
})
|
|
await self._update_run_status(run_id, RunStatus.STOPPED)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Stream error for run {run_id}: {e}")
|
|
await self._handle_stream_error(run_id, e)
|
|
|
|
finally:
|
|
self._cancellation_tokens.pop(run_id, None)
|
|
|
|
async def stop_run(self, run_id: UUID) -> None:
|
|
"""Stop a running task"""
|
|
if run_id in self._cancellation_tokens:
|
|
logger.info(f"Stopping run {run_id}")
|
|
self._cancellation_tokens[run_id].cancel()
|
|
|
|
# Send final message if connection still exists
|
|
if run_id in self._connections:
|
|
try:
|
|
await self._send_message(run_id, {
|
|
"type": "completion",
|
|
"status": "cancelled",
|
|
"timestamp": datetime.now(timezone.utc).isoformat()
|
|
})
|
|
except Exception:
|
|
pass
|
|
|
|
async def disconnect(self, run_id: UUID) -> None:
|
|
"""Clean up connection and associated resources"""
|
|
logger.info(f"Disconnecting run {run_id}")
|
|
|
|
# First cancel any running tasks
|
|
await self.stop_run(run_id)
|
|
|
|
# Then clean up resources without trying to close the socket again
|
|
if run_id in self._connections:
|
|
self._connections.pop(run_id, None)
|
|
self._cancellation_tokens.pop(run_id, None)
|
|
|
|
async def _send_message(self, run_id: UUID, message: dict) -> None:
|
|
"""Send a message through the WebSocket
|
|
|
|
Args:
|
|
run_id: UUID of the run
|
|
message: Message dictionary to send
|
|
"""
|
|
try:
|
|
if run_id in self._connections:
|
|
await self._connections[run_id].send_json(message)
|
|
except WebSocketDisconnect:
|
|
logger.warning(
|
|
f"WebSocket disconnected while sending message for run {run_id}")
|
|
await self.disconnect(run_id)
|
|
except Exception as e:
|
|
logger.error(f"Error sending message for run {run_id}: {e}")
|
|
await self._handle_stream_error(run_id, e)
|
|
|
|
async def _handle_stream_error(self, run_id: UUID, error: Exception) -> None:
|
|
"""Handle stream errors consistently
|
|
|
|
Args:
|
|
run_id: UUID of the run
|
|
error: Exception that occurred
|
|
"""
|
|
try:
|
|
await self._send_message(run_id, {
|
|
"type": "completion",
|
|
"status": "error",
|
|
"error": str(error),
|
|
"timestamp": datetime.now(timezone.utc).isoformat()
|
|
})
|
|
except Exception as send_error:
|
|
logger.error(
|
|
f"Failed to send error message for run {run_id}: {send_error}")
|
|
|
|
await self._update_run_status(run_id, RunStatus.ERROR, str(error))
|
|
|
|
def _format_message(self, message: Any) -> Optional[dict]:
|
|
"""Format message for WebSocket transmission
|
|
|
|
Args:
|
|
message: Message to format
|
|
|
|
Returns:
|
|
Optional[dict]: Formatted message or None if formatting fails
|
|
"""
|
|
try:
|
|
if isinstance(message, (InnerMessage, ChatMessage)):
|
|
return {
|
|
"type": "message",
|
|
"data": message.model_dump()
|
|
}
|
|
elif isinstance(message, TeamResult):
|
|
return {
|
|
"type": "result",
|
|
"data": message.model_dump(),
|
|
"status": "complete",
|
|
}
|
|
return None
|
|
except Exception as e:
|
|
logger.error(f"Message formatting error: {e}")
|
|
return None
|
|
|
|
async def _get_run(self, run_id: UUID) -> Optional[Run]:
|
|
"""Get run from database
|
|
|
|
Args:
|
|
run_id: UUID of the run to retrieve
|
|
|
|
Returns:
|
|
Optional[Run]: Run object if found, None otherwise
|
|
"""
|
|
response = self.db_manager.get(
|
|
Run, filters={"id": run_id}, return_json=False)
|
|
return response.data[0] if response.status and response.data else None
|
|
|
|
async def _update_run_status(
|
|
self,
|
|
run_id: UUID,
|
|
status: RunStatus,
|
|
error: Optional[str] = None
|
|
) -> None:
|
|
"""Update run status in database
|
|
|
|
Args:
|
|
run_id: UUID of the run to update
|
|
status: New status to set
|
|
error: Optional error message
|
|
"""
|
|
run = await self._get_run(run_id)
|
|
if run:
|
|
run.status = status
|
|
run.error_message = error
|
|
self.db_manager.upsert(run)
|
|
|
|
@property
|
|
def active_connections(self) -> set[UUID]:
|
|
"""Get set of active run IDs"""
|
|
return set(self._connections.keys())
|
|
|
|
@property
|
|
def active_runs(self) -> set[UUID]:
|
|
"""Get set of runs with active cancellation tokens"""
|
|
return set(self._cancellation_tokens.keys())
|