mirror of https://github.com/microsoft/autogen.git
Add sample chat application with FastAPI (#5433)
Introduce a sample chat application using AgentChat and FastAPI, demonstrating single-agent and team chat functionalities, along with state persistence and conversation history management.

Resolves #5423

---------

Co-authored-by: Victor Dibia <victor.dibia@gmail.com>
Co-authored-by: Victor Dibia <victordibia@microsoft.com>
parent f20ba9127d
commit abdc0da4f1
@ -0,0 +1,5 @@
model_config.yaml
agent_state.json
agent_history.json
team_state.json
team_history.json
README.md
@ -0,0 +1,70 @@
# AgentChat App with FastAPI

This sample demonstrates how to create a simple chat application using
[AgentChat](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/index.html)
and [FastAPI](https://fastapi.tiangolo.com/).

You will be using the following features of AgentChat:

1. Agent:
   - `AssistantAgent`
   - `UserProxyAgent` with a custom websocket input function
2. Team: `RoundRobinGroupChat`
3. State persistence: `save_state` and `load_state` methods of both agent and team.

## Setup

Install the required packages with OpenAI support:

```bash
pip install -U "autogen-ext[openai]" "fastapi" "uvicorn" "PyYAML"
```

To use models other than OpenAI, see the [Models](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/tutorial/models.html) documentation.

Create a new file named `model_config.yaml` in the same directory as this README file to configure your model settings.
See `model_config_template.yaml` for an example.
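
If you want to sanity-check the configuration before starting a server, you can load it the same way the apps do. This is an optional sketch, not one of the sample's files; it assumes `model_config.yaml` is in the current directory and the packages above are installed:

```python
# Optional helper (not part of the sample): verify model_config.yaml loads into a client.
import yaml

from autogen_core.models import ChatCompletionClient

with open("model_config.yaml", "r") as f:
    model_config = yaml.safe_load(f)

# Same call app_agent.py and app_team.py use to construct their model clients.
model_client = ChatCompletionClient.load_component(model_config)
print(f"Loaded model client: {type(model_client).__name__}")
```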

## Chat with a single agent

To start the FastAPI server for single-agent chat, run:

```bash
python app_agent.py
```

Visit http://localhost:8001 in your browser to start chatting.
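
The page served at that address talks to two plain HTTP endpoints defined in `app_agent.py`: `POST /chat`, which accepts and returns a `TextMessage` as JSON, and `GET /history`. As a rough sketch, you can exercise them without a browser using only the standard library; the prompt text below is arbitrary and the server is assumed to be running on port 8001:

```python
# Illustrative client for app_agent.py (not part of the sample).
import json
import urllib.request

# Send a TextMessage-shaped payload to the /chat endpoint.
payload = json.dumps({"content": "What is the capital of France?", "source": "user"}).encode("utf-8")
request = urllib.request.Request(
    "http://localhost:8001/chat",
    data=payload,
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(request) as response:
    reply = json.loads(response.read())
print(f"{reply['source']}: {reply['content']}")

# Fetch the conversation history stored in agent_history.json.
with urllib.request.urlopen("http://localhost:8001/history") as response:
    history = json.loads(response.read())
print(f"{len(history)} messages in history")
```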

## Chat with a team of agents

To start the FastAPI server for team chat, run:

```bash
python app_team.py
```

Visit http://localhost:8002 in your browser to start chatting.
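
Under the hood, the page keeps a websocket open to `ws://localhost:8002/ws/chat` and exchanges `TextMessage`-shaped JSON with the server, which streams back every message the team produces. The sketch below is a minimal non-browser client; it is not part of the sample and assumes the `websockets` package is installed (`pip install websockets`):

```python
# Illustrative websocket client for app_team.py (not part of the sample).
import asyncio
import json

import websockets  # assumed extra dependency: pip install websockets


async def main() -> None:
    async with websockets.connect("ws://localhost:8002/ws/chat") as ws:
        # Start a round by sending a TextMessage-shaped payload, as the browser does.
        await ws.send(json.dumps({"content": "Tell me a fact about space.", "source": "user"}))
        while True:
            message = json.loads(await ws.recv())
            if message.get("type") == "UserInputRequestedEvent":
                # The team is now waiting for the next user turn; stop here.
                break
            print(f"{message.get('source')}: {message.get('content')}")


asyncio.run(main())
```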

The team also includes a `UserProxyAgent` with a custom websocket input function
that allows the user to send messages to the team from the browser.

The team follows a round-robin strategy, so the agents take turns to respond.
When it is the user's turn, the input box will be enabled.
Once the user sends a message, the input box will be disabled and the agents
will take turns to respond.
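
The piece that makes this possible is the `input_func` passed to the `UserProxyAgent` in `app_team.py`: whenever the team asks for user input, the function awaits the next JSON message from the open websocket instead of reading from the console. A condensed sketch of that wiring is shown below; the helper name `make_user_proxy` is only for illustration:

```python
# Condensed from app_team.py: a websocket-backed input function for UserProxyAgent.
from autogen_agentchat.agents import UserProxyAgent
from autogen_agentchat.messages import TextMessage
from autogen_core import CancellationToken
from fastapi import WebSocket


def make_user_proxy(websocket: WebSocket) -> UserProxyAgent:  # illustrative helper name
    async def user_input(prompt: str, cancellation_token: CancellationToken | None) -> str:
        # Wait for the browser to send the next message over the websocket.
        data = await websocket.receive_json()
        return TextMessage.model_validate(data).content

    return UserProxyAgent(name="user", input_func=user_input)
```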

## State persistence

The agents and team use the `load_state` and `save_state` methods to load and save
their state from and to files on each turn.
For the agent, the state is saved to and loaded from `agent_state.json`.
For the team, the state is saved to and loaded from `team_state.json`.
You can inspect the state files to see the state of the agents and team
once you have chatted with them.

When the server restarts, the agents and team load their state from these files,
so their state is maintained across restarts.

Additionally, the apps use separate JSON files,
`agent_history.json` and `team_history.json`, to store the conversation history
for display in the browser.
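
Outside of the web apps, the same two calls are all that is needed: `save_state` returns a JSON-serializable mapping and `load_state` restores it. The sketch below round-trips a single agent's state through a file, mirroring what `app_agent.py` does on every request; `demo_state.json` and the prompts are only illustrative, and `model_config.yaml` from the Setup section is assumed:

```python
# Illustrative only: round-trip an AssistantAgent's state through a JSON file.
import asyncio
import json

import yaml

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import TextMessage
from autogen_core import CancellationToken
from autogen_core.models import ChatCompletionClient


def make_agent(model_client: ChatCompletionClient) -> AssistantAgent:
    return AssistantAgent(name="assistant", model_client=model_client, system_message="You are a helpful assistant.")


async def main() -> None:
    with open("model_config.yaml", "r") as f:
        model_client = ChatCompletionClient.load_component(yaml.safe_load(f))

    # Chat once, then persist the agent's state.
    agent = make_agent(model_client)
    await agent.on_messages(
        messages=[TextMessage(content="Remember that my favorite color is blue.", source="user")],
        cancellation_token=CancellationToken(),
    )
    with open("demo_state.json", "w") as f:
        json.dump(await agent.save_state(), f)

    # A fresh agent picks up where the previous one left off.
    restored = make_agent(model_client)
    with open("demo_state.json", "r") as f:
        await restored.load_state(json.load(f))
    response = await restored.on_messages(
        messages=[TextMessage(content="What is my favorite color?", source="user")],
        cancellation_token=CancellationToken(),
    )
    print(response.chat_message.content)


asyncio.run(main())
```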
app_agent.html
@ -0,0 +1,195 @@
<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>AutoGen FastAPI Sample: Agent</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            margin: 0;
            padding: 0;
            display: flex;
            flex-direction: column;
            align-items: center;
            justify-content: center;
            height: 100vh;
            background-color: #f0f0f0;
        }

        #chat-container {
            width: 90%;
            max-width: 600px;
            background-color: #fff;
            border-radius: 8px;
            box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
            padding: 20px;
        }

        #messages {
            height: 600px;
            overflow-y: auto;
            border-bottom: 1px solid #ddd;
            margin-bottom: 20px;
        }

        .message {
            margin: 10px 0;
        }

        .message.user {
            text-align: right;
        }

        .message.assistant {
            text-align: left;
        }

        .message.error {
            color: #721c24;
            background-color: #f8d7da;
            border: 1px solid #f5c6cb;
            padding: 10px;
            border-radius: 4px;
            margin: 10px 0;
        }

        .message.system {
            color: #0c5460;
            background-color: #d1ecf1;
            border: 1px solid #bee5eb;
            padding: 10px;
            border-radius: 4px;
            margin: 10px 0;
        }

        #input-container input:disabled,
        #input-container button:disabled {
            background-color: #e0e0e0;
            cursor: not-allowed;
        }

        #input-container {
            display: flex;
        }

        #input-container input {
            flex: 1;
            padding: 10px;
            border: 1px solid #ddd;
            border-radius: 4px;
        }

        #input-container button {
            padding: 10px 20px;
            border: none;
            background-color: #007bff;
            color: #fff;
            border-radius: 4px;
            cursor: pointer;
        }
    </style>
</head>

<body>
    <div id="chat-container">
        <div id="messages"></div>
        <div id="input-container">
            <input type="text" id="message-input" placeholder="Type a message...">
            <button onclick="sendMessage()">Send</button>
        </div>
    </div>

    <script>
        document.getElementById('message-input').addEventListener('keydown', function (event) {
            if (event.key === 'Enter') {
                sendMessage();
            }
        });

        async function sendMessage() {
            const input = document.getElementById('message-input');
            const button = document.querySelector('#input-container button');
            const message = input.value;
            if (!message) return;

            // Display user message
            displayMessage(message, 'user');

            // Clear input and disable controls
            input.value = '';
            input.disabled = true;
            button.disabled = true;

            try {
                const response = await fetch('http://localhost:8001/chat', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json'
                    },
                    body: JSON.stringify({ content: message, source: 'user' })
                });

                const data = await response.json();
                if (!response.ok) {
                    // Handle error response
                    if (data.detail && data.detail.type === 'error') {
                        displayMessage(data.detail.content, 'error');
                    } else {
                        displayMessage('Error: ' + (data.detail || 'Unknown error'), 'error');
                    }
                } else {
                    displayMessage(data.content, 'assistant');
                }
            } catch (error) {
                console.error('Error:', error);
                displayMessage('Error: Could not reach the server.', 'error');
            } finally {
                // Re-enable controls
                input.disabled = false;
                button.disabled = false;
                input.focus();
            }
        }

        function displayMessage(content, source) {
            const messagesContainer = document.getElementById('messages');
            const messageElement = document.createElement('div');
            messageElement.className = `message ${source}`;

            const labelElement = document.createElement('span');
            labelElement.className = 'label';
            labelElement.textContent = source;

            const contentElement = document.createElement('div');
            contentElement.className = 'content';
            contentElement.textContent = content;

            messageElement.appendChild(labelElement);
            messageElement.appendChild(contentElement);
            messagesContainer.appendChild(messageElement);
            messagesContainer.scrollTop = messagesContainer.scrollHeight;
        }

        async function loadHistory() {
            try {
                const response = await fetch('http://localhost:8001/history');
                if (!response.ok) {
                    throw new Error('Network response was not ok');
                }
                const history = await response.json();
                history.forEach(message => {
                    displayMessage(message.content, message.source);
                });
            } catch (error) {
                console.error('Error loading history:', error);
            }
        }

        // Load chat history when the page loads
        window.onload = loadHistory;
    </script>
</body>

</html>
app_agent.py
@ -0,0 +1,111 @@
import json
import os
from typing import Any

import aiofiles
import yaml
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import TextMessage
from autogen_core import CancellationToken
from autogen_core.models import ChatCompletionClient
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles

app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

# Serve static files
app.mount("/static", StaticFiles(directory="."), name="static")


@app.get("/")
async def root():
    """Serve the chat interface HTML file."""
    return FileResponse("app_agent.html")


model_config_path = "model_config.yaml"
state_path = "agent_state.json"
history_path = "agent_history.json"


async def get_agent() -> AssistantAgent:
    """Get the assistant agent, load state from file."""
    # Get model client from config.
    async with aiofiles.open(model_config_path, "r") as file:
        model_config = yaml.safe_load(await file.read())
    model_client = ChatCompletionClient.load_component(model_config)
    # Create the assistant agent.
    agent = AssistantAgent(
        name="assistant",
        model_client=model_client,
        system_message="You are a helpful assistant.",
    )
    # Load state from file.
    if not os.path.exists(state_path):
        return agent  # Return agent without loading state.
    async with aiofiles.open(state_path, "r") as file:
        state = json.loads(await file.read())
    await agent.load_state(state)
    return agent


async def get_history() -> list[dict[str, Any]]:
    """Get chat history from file."""
    if not os.path.exists(history_path):
        return []
    async with aiofiles.open(history_path, "r") as file:
        return json.loads(await file.read())


@app.get("/history")
async def history() -> list[dict[str, Any]]:
    try:
        return await get_history()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/chat", response_model=TextMessage)
async def chat(request: TextMessage) -> TextMessage:
    try:
        # Get the agent and respond to the message.
        agent = await get_agent()
        response = await agent.on_messages(messages=[request], cancellation_token=CancellationToken())

        # Save agent state to file.
        state = await agent.save_state()
        async with aiofiles.open(state_path, "w") as file:
            await file.write(json.dumps(state))

        # Save chat history to file.
        history = await get_history()
        history.append(request.model_dump())
        history.append(response.chat_message.model_dump())
        async with aiofiles.open(history_path, "w") as file:
            await file.write(json.dumps(history))

        assert isinstance(response.chat_message, TextMessage)
        return response.chat_message
    except Exception as e:
        error_message = {
            "type": "error",
            "content": f"Error: {str(e)}",
            "source": "system"
        }
        raise HTTPException(status_code=500, detail=error_message) from e


# Example usage
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8001)
app_team.html
@ -0,0 +1,217 @@
<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>AutoGen FastAPI Sample: Team</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            margin: 0;
            padding: 0;
            display: flex;
            flex-direction: column;
            align-items: center;
            justify-content: center;
            height: 100vh;
            background-color: #f0f0f0;
        }

        #chat-container {
            width: 90%;
            max-width: 600px;
            background-color: #fff;
            border-radius: 8px;
            box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
            padding: 20px;
        }

        #messages {
            height: 600px;
            overflow-y: auto;
            border-bottom: 1px solid #ddd;
            margin-bottom: 20px;
        }

        .message {
            margin: 10px 0;
        }

        .message.user {
            text-align: right;
        }

        .message.assistant {
            text-align: left;
        }

        .label {
            font-weight: bold;
            display: block;
        }

        .content {
            margin-top: 5px;
        }

        #input-container {
            display: flex;
        }

        #input-container input {
            flex: 1;
            padding: 10px;
            border: 1px solid #ddd;
            border-radius: 4px;
        }

        #input-container button {
            padding: 10px 20px;
            border: none;
            background-color: #007bff;
            color: #fff;
            border-radius: 4px;
            cursor: pointer;
        }

        #input-container input:disabled,
        #input-container button:disabled {
            background-color: #e0e0e0;
            cursor: not-allowed;
        }

        .message.error {
            color: #721c24;
            background-color: #f8d7da;
            border: 1px solid #f5c6cb;
            padding: 10px;
            border-radius: 4px;
            margin: 10px 0;
        }

        .message.system {
            color: #0c5460;
            background-color: #d1ecf1;
            border: 1px solid #bee5eb;
            padding: 10px;
            border-radius: 4px;
            margin: 10px 0;
        }
    </style>
</head>

<body>
    <div id="chat-container">
        <div id="messages"></div>
        <div id="input-container">
            <input type="text" id="message-input" placeholder="Type a message...">
            <button id="send-button" onclick="sendMessage()">Send</button>
        </div>
    </div>

    <script>
        const ws = new WebSocket('ws://localhost:8002/ws/chat');

        ws.onmessage = function (event) {
            const message = JSON.parse(event.data);

            if (message.type === 'UserInputRequestedEvent') {
                // Re-enable input and send button if UserInputRequestedEvent is received
                enableInput();
            }
            else if (message.type === 'error') {
                // Display error message
                displayMessage(message.content, 'error');
                enableInput();
            }
            else {
                // Display regular message
                displayMessage(message.content, message.source);
            }
        };

        ws.onerror = function (error) {
            displayMessage("WebSocket error occurred. Please refresh the page.", 'error');
            enableInput();
        };

        ws.onclose = function () {
            displayMessage("Connection closed. Please refresh the page.", 'system');
            disableInput();
        };

        document.getElementById('message-input').addEventListener('keydown', function (event) {
            if (event.key === 'Enter' && !event.target.disabled) {
                sendMessage();
            }
        });

        async function sendMessage() {
            const input = document.getElementById('message-input');
            const button = document.getElementById('send-button');
            const message = input.value;
            if (!message) return;

            // Clear input and disable input and send button
            input.value = '';
            disableInput();

            // Send message to WebSocket
            ws.send(JSON.stringify({ content: message, source: 'user' }));
        }

        function displayMessage(content, source) {
            const messagesContainer = document.getElementById('messages');
            const messageElement = document.createElement('div');
            messageElement.className = `message ${source}`;

            const labelElement = document.createElement('span');
            labelElement.className = 'label';
            labelElement.textContent = source;

            const contentElement = document.createElement('div');
            contentElement.className = 'content';
            contentElement.textContent = content;

            messageElement.appendChild(labelElement);
            messageElement.appendChild(contentElement);
            messagesContainer.appendChild(messageElement);
            messagesContainer.scrollTop = messagesContainer.scrollHeight;
        }

        function disableInput() {
            const input = document.getElementById('message-input');
            const button = document.getElementById('send-button');
            input.disabled = true;
            button.disabled = true;
        }

        function enableInput() {
            const input = document.getElementById('message-input');
            const button = document.getElementById('send-button');
            input.disabled = false;
            button.disabled = false;
        }

        async function loadHistory() {
            try {
                const response = await fetch('http://localhost:8002/history');
                if (!response.ok) {
                    throw new Error('Network response was not ok');
                }
                const history = await response.json();
                history.forEach(message => {
                    displayMessage(message.content, message.source);
                });
            } catch (error) {
                console.error('Error loading history:', error);
            }
        }

        // Load chat history when the page loads
        window.onload = loadHistory;
    </script>
</body>

</html>
app_team.py
@ -0,0 +1,166 @@
import json
import logging
import os
from typing import Any, Awaitable, Callable, Optional

import aiofiles
import yaml
from autogen_agentchat.agents import AssistantAgent, UserProxyAgent
from autogen_agentchat.base import TaskResult
from autogen_agentchat.messages import TextMessage, UserInputRequestedEvent
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_core import CancellationToken
from autogen_core.models import ChatCompletionClient
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles

logger = logging.getLogger(__name__)

app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

model_config_path = "model_config.yaml"
state_path = "team_state.json"
history_path = "team_history.json"

# Serve static files
app.mount("/static", StaticFiles(directory="."), name="static")


@app.get("/")
async def root():
    """Serve the chat interface HTML file."""
    return FileResponse("app_team.html")


async def get_team(
    user_input_func: Callable[[str, Optional[CancellationToken]], Awaitable[str]],
) -> RoundRobinGroupChat:
    # Get model client from config.
    async with aiofiles.open(model_config_path, "r") as file:
        model_config = yaml.safe_load(await file.read())
    model_client = ChatCompletionClient.load_component(model_config)
    # Create the team.
    agent = AssistantAgent(
        name="assistant",
        model_client=model_client,
        system_message="You are a helpful assistant.",
    )
    yoda = AssistantAgent(
        name="yoda",
        model_client=model_client,
        system_message="Repeat the same message in the tone of Yoda.",
    )
    user_proxy = UserProxyAgent(
        name="user",
        input_func=user_input_func,  # Use the user input function.
    )
    team = RoundRobinGroupChat(
        [agent, yoda, user_proxy],
    )
    # Load state from file.
    if not os.path.exists(state_path):
        return team
    async with aiofiles.open(state_path, "r") as file:
        state = json.loads(await file.read())
    await team.load_state(state)
    return team


async def get_history() -> list[dict[str, Any]]:
    """Get chat history from file."""
    if not os.path.exists(history_path):
        return []
    async with aiofiles.open(history_path, "r") as file:
        return json.loads(await file.read())


@app.get("/history")
async def history() -> list[dict[str, Any]]:
    try:
        return await get_history()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.websocket("/ws/chat")
async def chat(websocket: WebSocket):
    await websocket.accept()

    # User input function used by the team.
    async def _user_input(prompt: str, cancellation_token: CancellationToken | None) -> str:
        data = await websocket.receive_json()
        message = TextMessage.model_validate(data)
        return message.content

    try:
        while True:
            # Get user message.
            data = await websocket.receive_json()
            request = TextMessage.model_validate(data)

            try:
                # Get the team and respond to the message.
                team = await get_team(_user_input)
                history = await get_history()
                stream = team.run_stream(task=request)
                async for message in stream:
                    if isinstance(message, TaskResult):
                        continue
                    await websocket.send_json(message.model_dump())
                    if not isinstance(message, UserInputRequestedEvent):
                        # Don't save user input events to history.
                        history.append(message.model_dump())

                # Save team state to file.
                async with aiofiles.open(state_path, "w") as file:
                    state = await team.save_state()
                    await file.write(json.dumps(state))

                # Save chat history to file.
                async with aiofiles.open(history_path, "w") as file:
                    await file.write(json.dumps(history))

            except Exception as e:
                # Send error message to client
                error_message = {
                    "type": "error",
                    "content": f"Error: {str(e)}",
                    "source": "system"
                }
                await websocket.send_json(error_message)
                # Re-enable input after error
                await websocket.send_json({
                    "type": "UserInputRequestedEvent",
                    "content": "An error occurred. Please try again.",
                    "source": "system"
                })

    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        try:
            await websocket.send_json({
                "type": "error",
                "content": f"Unexpected error: {str(e)}",
                "source": "system"
            })
        except:
            pass


# Example usage
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8002)
model_config_template.yaml
@ -0,0 +1,26 @@
# Use Open AI with key
provider: autogen_ext.models.openai.OpenAIChatCompletionClient
config:
  model: gpt-4o
  api_key: REPLACE_WITH_YOUR_API_KEY
# Use Azure Open AI with key
# provider: autogen_ext.models.openai.AzureOpenAIChatCompletionClient
# config:
#   model: gpt-4o
#   azure_endpoint: https://{your-custom-endpoint}.openai.azure.com/
#   azure_deployment: {your-azure-deployment}
#   api_version: {your-api-version}
#   api_key: REPLACE_WITH_YOUR_API_KEY
# Use Azure OpenAI with AD token provider.
# provider: autogen_ext.models.openai.AzureOpenAIChatCompletionClient
# config:
#   model: gpt-4o
#   azure_endpoint: https://{your-custom-endpoint}.openai.azure.com/
#   azure_deployment: {your-azure-deployment}
#   api_version: {your-api-version}
#   azure_ad_token_provider:
#     provider: autogen_ext.auth.azure.AzureTokenProvider
#     config:
#       provider_kind: DefaultAzureCredential
#       scopes:
#         - https://cognitiveservices.azure.com/.default