mirror of https://github.com/microsoft/autogen.git
167 lines
5.4 KiB
Python
167 lines
5.4 KiB
Python
import json
|
|
import logging
|
|
import os
|
|
from typing import Any, Awaitable, Callable, Optional
|
|
|
|
import aiofiles
|
|
import yaml
|
|
from autogen_agentchat.agents import AssistantAgent, UserProxyAgent
|
|
from autogen_agentchat.base import TaskResult
|
|
from autogen_agentchat.messages import TextMessage, UserInputRequestedEvent
|
|
from autogen_agentchat.teams import RoundRobinGroupChat
|
|
from autogen_core import CancellationToken
|
|
from autogen_core.models import ChatCompletionClient
|
|
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import FileResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Relative file locations used by the handlers below: the model client
# configuration, the persisted team state, and the persisted chat history.
model_config_path = "model_config.yaml"
state_path = "team_state.json"
history_path = "team_history.json"

app = FastAPI()

# CORS: accept every origin, HTTP method and header, with credentials enabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve the current directory under /static.
# NOTE(review): directory="." is the same relative location the config/state
# paths above point into, so those files are presumably reachable through
# /static as well -- confirm this is intended outside of a local demo.
app.mount("/static", StaticFiles(directory="."), name="static")
|
|
|
|
@app.get("/")
async def root():
    """Return the chat UI page, ``app_team.html``, as a file response."""
    chat_page = FileResponse("app_team.html")
    return chat_page
|
|
|
|
|
|
async def get_team(
    user_input_func: Callable[[str, Optional[CancellationToken]], Awaitable[str]],
) -> RoundRobinGroupChat:
    """Build the round-robin team and restore its saved state if one exists.

    Args:
        user_input_func: Async callable handed to the ``UserProxyAgent`` as
            its ``input_func``; it receives a prompt and an optional
            cancellation token and returns the user's reply.

    Returns:
        A ``RoundRobinGroupChat`` of the "assistant" agent, the "yoda" agent
        and the "user" proxy, in that order. When the file at ``state_path``
        exists, its JSON content is loaded into the team before returning.
    """
    # The model client is described by a YAML component config on disk.
    async with aiofiles.open(model_config_path, "r") as config_file:
        raw_config = await config_file.read()
    model_client = ChatCompletionClient.load_component(yaml.safe_load(raw_config))

    # Both assistants share one model client; the user proxy pulls its input
    # from the caller-supplied function.
    participants = [
        AssistantAgent(
            name="assistant",
            model_client=model_client,
            system_message="You are a helpful assistant.",
        ),
        AssistantAgent(
            name="yoda",
            model_client=model_client,
            system_message="Repeat the same message in the tone of Yoda.",
        ),
        UserProxyAgent(
            name="user",
            input_func=user_input_func,
        ),
    ]
    team = RoundRobinGroupChat(participants)

    # Restore the previous conversation state only when a state file is present.
    if os.path.exists(state_path):
        async with aiofiles.open(state_path, "r") as state_file:
            saved_state = json.loads(await state_file.read())
        await team.load_state(saved_state)
    return team
|
|
|
|
|
|
async def get_history() -> list[dict[str, Any]]:
    """Load the persisted chat history.

    Returns:
        The JSON content of the file at ``history_path``, or an empty list
        when that file does not exist.
    """
    if os.path.exists(history_path):
        async with aiofiles.open(history_path, "r") as history_file:
            raw_history = await history_file.read()
        return json.loads(raw_history)
    return []
|
|
|
|
|
|
@app.get("/history")
async def history() -> list[dict[str, Any]]:
    """Return the stored chat history.

    Raises:
        HTTPException: With status 500 and the error text as detail when
            reading the history fails for any reason.
    """
    try:
        records = await get_history()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return records
|
|
|
|
|
|
@app.websocket("/ws/chat")
async def chat(websocket: WebSocket) -> None:
    """Run a chat session between a WebSocket client and the agent team.

    Each JSON message received from the client is validated as a
    ``TextMessage`` and run as a task on a team obtained from ``get_team``.
    Every streamed item except the final ``TaskResult`` is forwarded to the
    client, and every forwarded item other than a ``UserInputRequestedEvent``
    is appended to the chat history. After the run, the team state and the
    history are written to ``state_path`` and ``history_path``.

    An error raised while handling one task is reported to the client as an
    ``error`` message followed by a message of type
    ``UserInputRequestedEvent``, and the loop then waits for the next task.
    A ``WebSocketDisconnect`` ends the session and is logged at info level;
    any other error escaping the loop is logged and reported to the client
    on a best-effort basis.

    Args:
        websocket: The client connection; it is accepted on entry.
    """
    await websocket.accept()

    # User input function used by the team.
    async def _user_input(prompt: str, cancellation_token: CancellationToken | None) -> str:
        data = await websocket.receive_json()
        message = TextMessage.model_validate(data)
        return message.content

    try:
        while True:
            # Get user message.
            data = await websocket.receive_json()
            request = TextMessage.model_validate(data)

            try:
                # Get the team and respond to the message.
                team = await get_team(_user_input)
                # Named chat_history so it does not shadow the module-level
                # ``history`` route function.
                chat_history = await get_history()
                async for message in team.run_stream(task=request):
                    if isinstance(message, TaskResult):
                        continue
                    # Dump once and reuse for both the client and the history.
                    payload = message.model_dump()
                    await websocket.send_json(payload)
                    if not isinstance(message, UserInputRequestedEvent):
                        # Don't save user input events to history.
                        chat_history.append(payload)

                # Serialize BEFORE opening the files: mode "w" truncates on
                # open, so a failure in save_state()/json.dumps() after the
                # open would leave an empty file that later loads cannot parse.
                serialized_state = json.dumps(await team.save_state())
                serialized_history = json.dumps(chat_history)

                # Save team state to file.
                async with aiofiles.open(state_path, "w") as file:
                    await file.write(serialized_state)

                # Save chat history to file.
                async with aiofiles.open(history_path, "w") as file:
                    await file.write(serialized_history)

            except WebSocketDisconnect:
                # The client is gone, so do not try to send an error report on
                # this socket; let the outer handler log the disconnect.
                raise
            except Exception as e:
                # Send error message to client
                error_message = {
                    "type": "error",
                    "content": f"Error: {str(e)}",
                    "source": "system",
                }
                await websocket.send_json(error_message)
                # Re-enable input after error
                await websocket.send_json(
                    {
                        "type": "UserInputRequestedEvent",
                        "content": "An error occurred. Please try again.",
                        "source": "system",
                    }
                )

    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except Exception as e:
        # logger.exception keeps the same message and adds the traceback.
        logger.exception("Unexpected error: %s", e)
        try:
            await websocket.send_json(
                {
                    "type": "error",
                    "content": f"Unexpected error: {str(e)}",
                    "source": "system",
                }
            )
        except Exception:
            # Best-effort notification only. Catching Exception rather than
            # using a bare ``except`` lets task cancellation and interpreter
            # shutdown signals propagate.
            logger.debug("Could not report the error to the client", exc_info=True)
|
|
|
|
|
|
# Example usage: start the app with uvicorn when this file is run as a script.
if __name__ == "__main__":
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 8002
    uvicorn.run(app, host=bind_host, port=bind_port)
|