Add sample chat application with FastAPI (#5433)

Introduce a sample chat application using AgentChat and FastAPI,
demonstrating single-agent and team chat functionalities, along with
state persistence and conversation history management.

Resolves #5423

---------

Co-authored-by: Victor Dibia <victor.dibia@gmail.com>
Co-authored-by: Victor Dibia <victordibia@microsoft.com>
This commit is contained in:
Eric Zhu 2025-02-07 12:17:56 -08:00 committed by GitHub
parent f20ba9127d
commit abdc0da4f1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 790 additions and 0 deletions

View File

@ -0,0 +1,5 @@
model_config.yaml
agent_state.json
agent_history.json
team_state.json
team_history.json

View File

@ -0,0 +1,70 @@
# AgentChat App with FastAPI
This sample demonstrates how to create a simple chat application using
[AgentChat](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/index.html)
and [FastAPI](https://fastapi.tiangolo.com/).
You will be using the following features of AgentChat:
1. Agent:
- `AssistantAgent`
- `UserProxyAgent` with a custom websocket input function
2. Team: `RoundRobinGroupChat`
3. State persistence: `save_state` and `load_state` methods of both agent and team.
## Setup
Install the required packages with OpenAI support:
```bash
pip install -U "autogen-ext[openai]" "fastapi" "uvicorn" "PyYAML"
```
To use models other than OpenAI, see the [Models](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/tutorial/models.html) documentation.
Create a new file named `model_config.yaml` in the same directory as this README file to configure your model settings.
See `model_config_template.yaml` for an example.
## Chat with a single agent
To start the FastAPI server for single-agent chat, run:
```bash
python app_agent.py
```
Visit http://localhost:8001 in your browser to start chatting.
## Chat with a team of agents
To start the FastAPI server for team chat, run:
```bash
python app_team.py
```
Visit http://localhost:8002 in your browser to start chatting.
The team also includes a `UserProxyAgent` agent with a custom websocket input function
that allows the user to send messages to the team from the browser.
The team follows a round-robin strategy so each agent will take turns to respond.
When it is the user's turn, the input box will be enabled.
Once the user sends a message, the input box will be disabled and the agents
will take turns to respond.
## State persistence
The agents and team use the `load_state` and `save_state` methods to load and save
their state from and to files on each turn.
For the agent, the state is saved to and loaded from `agent_state.json`.
For the team, the state is saved to and loaded from `team_state.json`.
You can inspect the state files to see the state of the agents and team
once you have chatted with them.
When the server restarts, the agents and team will load their state from the state files
to maintain their state across restarts.
Additionally, the apps uses separate JSON files,
`agent_history.json` and `team_history.json`, to store the conversation history
for display in the browser.

View File

@ -0,0 +1,195 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>AutoGen FastAPI Sample: Agent</title>
<style>
body {
font-family: Arial, sans-serif;
margin: 0;
padding: 0;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
height: 100vh;
background-color: #f0f0f0;
}
#chat-container {
width: 90%;
max-width: 600px;
background-color: #fff;
border-radius: 8px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
padding: 20px;
}
#messages {
height: 600px;
overflow-y: auto;
border-bottom: 1px solid #ddd;
margin-bottom: 20px;
}
.message {
margin: 10px 0;
}
.message.user {
text-align: right;
}
.message.assistant {
text-align: left;
}
.message.error {
color: #721c24;
background-color: #f8d7da;
border: 1px solid #f5c6cb;
padding: 10px;
border-radius: 4px;
margin: 10px 0;
}
.message.system {
color: #0c5460;
background-color: #d1ecf1;
border: 1px solid #bee5eb;
padding: 10px;
border-radius: 4px;
margin: 10px 0;
}
#input-container input:disabled,
#input-container button:disabled {
background-color: #e0e0e0;
cursor: not-allowed;
}
#input-container {
display: flex;
}
#input-container input {
flex: 1;
padding: 10px;
border: 1px solid #ddd;
border-radius: 4px;
}
#input-container button {
padding: 10px 20px;
border: none;
background-color: #007bff;
color: #fff;
border-radius: 4px;
cursor: pointer;
}
</style>
</head>
<body>
<div id="chat-container">
<div id="messages"></div>
<div id="input-container">
<input type="text" id="message-input" placeholder="Type a message...">
<button onclick="sendMessage()">Send</button>
</div>
</div>
<script>
document.getElementById('message-input').addEventListener('keydown', function (event) {
if (event.key === 'Enter') {
sendMessage();
}
});
async function sendMessage() {
const input = document.getElementById('message-input');
const button = document.querySelector('#input-container button');
const message = input.value;
if (!message) return;
// Display user message
displayMessage(message, 'user');
// Clear input and disable controls
input.value = '';
input.disabled = true;
button.disabled = true;
try {
const response = await fetch('http://localhost:8001/chat', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({ content: message, source: 'user' })
});
const data = await response.json();
if (!response.ok) {
// Handle error response
if (data.detail && data.detail.type === 'error') {
displayMessage(data.detail.content, 'error');
} else {
displayMessage('Error: ' + (data.detail || 'Unknown error'), 'error');
}
} else {
displayMessage(data.content, 'assistant');
}
} catch (error) {
console.error('Error:', error);
displayMessage('Error: Could not reach the server.', 'error');
} finally {
// Re-enable controls
input.disabled = false;
button.disabled = false;
input.focus();
}
}
function displayMessage(content, source) {
const messagesContainer = document.getElementById('messages');
const messageElement = document.createElement('div');
messageElement.className = `message ${source}`;
const labelElement = document.createElement('span');
labelElement.className = 'label';
labelElement.textContent = source;
const contentElement = document.createElement('div');
contentElement.className = 'content';
contentElement.textContent = content;
messageElement.appendChild(labelElement);
messageElement.appendChild(contentElement);
messagesContainer.appendChild(messageElement);
messagesContainer.scrollTop = messagesContainer.scrollHeight;
}
async function loadHistory() {
try {
const response = await fetch('http://localhost:8001/history');
if (!response.ok) {
throw new Error('Network response was not ok');
}
const history = await response.json();
history.forEach(message => {
displayMessage(message.content, message.source);
});
} catch (error) {
console.error('Error loading history:', error);
}
}
// Load chat history when the page loads
window.onload = loadHistory;
</script>
</body>
</html>

View File

@ -0,0 +1,111 @@
import json
import os
from typing import Any
import aiofiles
import yaml
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import TextMessage
from autogen_core import CancellationToken
from autogen_core.models import ChatCompletionClient
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
app = FastAPI()
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allows all origins
allow_credentials=True,
allow_methods=["*"], # Allows all methods
allow_headers=["*"], # Allows all headers
)
# Serve static files
app.mount("/static", StaticFiles(directory="."), name="static")
@app.get("/")
async def root():
"""Serve the chat interface HTML file."""
return FileResponse("app_agent.html")
model_config_path = "model_config.yaml"
state_path = "agent_state.json"
history_path = "agent_history.json"
async def get_agent() -> AssistantAgent:
"""Get the assistant agent, load state from file."""
# Get model client from config.
async with aiofiles.open(model_config_path, "r") as file:
model_config = yaml.safe_load(await file.read())
model_client = ChatCompletionClient.load_component(model_config)
# Create the assistant agent.
agent = AssistantAgent(
name="assistant",
model_client=model_client,
system_message="You are a helpful assistant.",
)
# Load state from file.
if not os.path.exists(state_path):
return agent # Return agent without loading state.
async with aiofiles.open(state_path, "r") as file:
state = json.loads(await file.read())
await agent.load_state(state)
return agent
async def get_history() -> list[dict[str, Any]]:
"""Get chat history from file."""
if not os.path.exists(history_path):
return []
async with aiofiles.open(history_path, "r") as file:
return json.loads(await file.read())
@app.get("/history")
async def history() -> list[dict[str, Any]]:
try:
return await get_history()
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/chat", response_model=TextMessage)
async def chat(request: TextMessage) -> TextMessage:
try:
# Get the agent and respond to the message.
agent = await get_agent()
response = await agent.on_messages(messages=[request], cancellation_token=CancellationToken())
# Save agent state to file.
state = await agent.save_state()
async with aiofiles.open(state_path, "w") as file:
await file.write(json.dumps(state))
# Save chat history to file.
history = await get_history()
history.append(request.model_dump())
history.append(response.chat_message.model_dump())
async with aiofiles.open(history_path, "w") as file:
await file.write(json.dumps(history))
assert isinstance(response.chat_message, TextMessage)
return response.chat_message
except Exception as e:
error_message = {
"type": "error",
"content": f"Error: {str(e)}",
"source": "system"
}
raise HTTPException(status_code=500, detail=error_message) from e
# Example usage
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8001)

View File

@ -0,0 +1,217 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>AutoGen FastAPI Sample: Team</title>
<style>
body {
font-family: Arial, sans-serif;
margin: 0;
padding: 0;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
height: 100vh;
background-color: #f0f0f0;
}
#chat-container {
width: 90%;
max-width: 600px;
background-color: #fff;
border-radius: 8px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
padding: 20px;
}
#messages {
height: 600px;
overflow-y: auto;
border-bottom: 1px solid #ddd;
margin-bottom: 20px;
}
.message {
margin: 10px 0;
}
.message.user {
text-align: right;
}
.message.assistant {
text-align: left;
}
.label {
font-weight: bold;
display: block;
}
.content {
margin-top: 5px;
}
#input-container {
display: flex;
}
#input-container input {
flex: 1;
padding: 10px;
border: 1px solid #ddd;
border-radius: 4px;
}
#input-container button {
padding: 10px 20px;
border: none;
background-color: #007bff;
color: #fff;
border-radius: 4px;
cursor: pointer;
}
#input-container input:disabled,
#input-container button:disabled {
background-color: #e0e0e0;
cursor: not-allowed;
}
.message.error {
color: #721c24;
background-color: #f8d7da;
border: 1px solid #f5c6cb;
padding: 10px;
border-radius: 4px;
margin: 10px 0;
}
.message.system {
color: #0c5460;
background-color: #d1ecf1;
border: 1px solid #bee5eb;
padding: 10px;
border-radius: 4px;
margin: 10px 0;
}
</style>
</head>
<body>
<div id="chat-container">
<div id="messages"></div>
<div id="input-container">
<input type="text" id="message-input" placeholder="Type a message...">
<button id="send-button" onclick="sendMessage()">Send</button>
</div>
</div>
<script>
const ws = new WebSocket('ws://localhost:8002/ws/chat');
ws.onmessage = function (event) {
const message = JSON.parse(event.data);
if (message.type === 'UserInputRequestedEvent') {
// Re-enable input and send button if UserInputRequestedEvent is received
enableInput();
}
else if (message.type === 'error') {
// Display error message
displayMessage(message.content, 'error');
enableInput();
}
else {
// Display regular message
displayMessage(message.content, message.source);
}
};
ws.onerror = function(error) {
displayMessage("WebSocket error occurred. Please refresh the page.", 'error');
enableInput();
};
ws.onclose = function() {
displayMessage("Connection closed. Please refresh the page.", 'system');
disableInput();
};
document.getElementById('message-input').addEventListener('keydown', function (event) {
if (event.key === 'Enter' && !event.target.disabled) {
sendMessage();
}
});
async function sendMessage() {
const input = document.getElementById('message-input');
const button = document.getElementById('send-button');
const message = input.value;
if (!message) return;
// Clear input and disable input and send button
input.value = '';
disableInput();
// Send message to WebSocket
ws.send(JSON.stringify({ content: message, source: 'user' }));
}
function displayMessage(content, source) {
const messagesContainer = document.getElementById('messages');
const messageElement = document.createElement('div');
messageElement.className = `message ${source}`;
const labelElement = document.createElement('span');
labelElement.className = 'label';
labelElement.textContent = source;
const contentElement = document.createElement('div');
contentElement.className = 'content';
contentElement.textContent = content;
messageElement.appendChild(labelElement);
messageElement.appendChild(contentElement);
messagesContainer.appendChild(messageElement);
messagesContainer.scrollTop = messagesContainer.scrollHeight;
}
function disableInput() {
const input = document.getElementById('message-input');
const button = document.getElementById('send-button');
input.disabled = true;
button.disabled = true;
}
function enableInput() {
const input = document.getElementById('message-input');
const button = document.getElementById('send-button');
input.disabled = false;
button.disabled = false;
}
async function loadHistory() {
try {
const response = await fetch('http://localhost:8002/history');
if (!response.ok) {
throw new Error('Network response was not ok');
}
const history = await response.json();
history.forEach(message => {
displayMessage(message.content, message.source);
});
} catch (error) {
console.error('Error loading history:', error);
}
}
// Load chat history when the page loads
window.onload = loadHistory;
</script>
</body>
</html>

View File

@ -0,0 +1,166 @@
import json
import logging
import os
from typing import Any, Awaitable, Callable, Optional
import aiofiles
import yaml
from autogen_agentchat.agents import AssistantAgent, UserProxyAgent
from autogen_agentchat.base import TaskResult
from autogen_agentchat.messages import TextMessage, UserInputRequestedEvent
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_core import CancellationToken
from autogen_core.models import ChatCompletionClient
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
logger = logging.getLogger(__name__)
app = FastAPI()
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allows all origins
allow_credentials=True,
allow_methods=["*"], # Allows all methods
allow_headers=["*"], # Allows all headers
)
model_config_path = "model_config.yaml"
state_path = "team_state.json"
history_path = "team_history.json"
# Serve static files
app.mount("/static", StaticFiles(directory="."), name="static")
@app.get("/")
async def root():
"""Serve the chat interface HTML file."""
return FileResponse("app_team.html")
async def get_team(
user_input_func: Callable[[str, Optional[CancellationToken]], Awaitable[str]],
) -> RoundRobinGroupChat:
# Get model client from config.
async with aiofiles.open(model_config_path, "r") as file:
model_config = yaml.safe_load(await file.read())
model_client = ChatCompletionClient.load_component(model_config)
# Create the team.
agent = AssistantAgent(
name="assistant",
model_client=model_client,
system_message="You are a helpful assistant.",
)
yoda = AssistantAgent(
name="yoda",
model_client=model_client,
system_message="Repeat the same message in the tone of Yoda.",
)
user_proxy = UserProxyAgent(
name="user",
input_func=user_input_func, # Use the user input function.
)
team = RoundRobinGroupChat(
[agent, yoda, user_proxy],
)
# Load state from file.
if not os.path.exists(state_path):
return team
async with aiofiles.open(state_path, "r") as file:
state = json.loads(await file.read())
await team.load_state(state)
return team
async def get_history() -> list[dict[str, Any]]:
"""Get chat history from file."""
if not os.path.exists(history_path):
return []
async with aiofiles.open(history_path, "r") as file:
return json.loads(await file.read())
@app.get("/history")
async def history() -> list[dict[str, Any]]:
try:
return await get_history()
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
@app.websocket("/ws/chat")
async def chat(websocket: WebSocket):
await websocket.accept()
# User input function used by the team.
async def _user_input(prompt: str, cancellation_token: CancellationToken | None) -> str:
data = await websocket.receive_json()
message = TextMessage.model_validate(data)
return message.content
try:
while True:
# Get user message.
data = await websocket.receive_json()
request = TextMessage.model_validate(data)
try:
# Get the team and respond to the message.
team = await get_team(_user_input)
history = await get_history()
stream = team.run_stream(task=request)
async for message in stream:
if isinstance(message, TaskResult):
continue
await websocket.send_json(message.model_dump())
if not isinstance(message, UserInputRequestedEvent):
# Don't save user input events to history.
history.append(message.model_dump())
# Save team state to file.
async with aiofiles.open(state_path, "w") as file:
state = await team.save_state()
await file.write(json.dumps(state))
# Save chat history to file.
async with aiofiles.open(history_path, "w") as file:
await file.write(json.dumps(history))
except Exception as e:
# Send error message to client
error_message = {
"type": "error",
"content": f"Error: {str(e)}",
"source": "system"
}
await websocket.send_json(error_message)
# Re-enable input after error
await websocket.send_json({
"type": "UserInputRequestedEvent",
"content": "An error occurred. Please try again.",
"source": "system"
})
except WebSocketDisconnect:
logger.info("Client disconnected")
except Exception as e:
logger.error(f"Unexpected error: {str(e)}")
try:
await websocket.send_json({
"type": "error",
"content": f"Unexpected error: {str(e)}",
"source": "system"
})
except:
pass
# Example usage
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8002)

View File

@ -0,0 +1,26 @@
# Use Open AI with key
provider: autogen_ext.models.openai.OpenAIChatCompletionClient
config:
model: gpt-4o
api_key: REPLACE_WITH_YOUR_API_KEY
# Use Azure Open AI with key
# provider: autogen_ext.models.openai.AzureOpenAIChatCompletionClient
# config:
# model: gpt-4o
# azure_endpoint: https://{your-custom-endpoint}.openai.azure.com/
# azure_deployment: {your-azure-deployment}
# api_version: {your-api-version}
# api_key: REPLACE_WITH_YOUR_API_KEY
# Use Azure OpenAI with AD token provider.
# provider: autogen_ext.models.openai.AzureOpenAIChatCompletionClient
# config:
# model: gpt-4o
# azure_endpoint: https://{your-custom-endpoint}.openai.azure.com/
# azure_deployment: {your-azure-deployment}
# api_version: {your-api-version}
# azure_ad_token_provider:
# provider: autogen_ext.auth.azure.AzureTokenProvider
# config:
# provider_kind: DefaultAzureCredential
# scopes:
# - https://cognitiveservices.azure.com/.default