mirror of https://github.com/microsoft/autogen.git
Add sample chat application with FastAPI (#5433)
Introduce a sample chat application using AgentChat and FastAPI, demonstrating single-agent and team chat functionalities, along with state persistence and conversation history management. Resolves #5423 --------- Co-authored-by: Victor Dibia <victor.dibia@gmail.com> Co-authored-by: Victor Dibia <victordibia@microsoft.com>
This commit is contained in:
parent
f20ba9127d
commit
abdc0da4f1
|
@ -0,0 +1,5 @@
|
|||
model_config.yaml
|
||||
agent_state.json
|
||||
agent_history.json
|
||||
team_state.json
|
||||
team_history.json
|
|
@ -0,0 +1,70 @@
|
|||
# AgentChat App with FastAPI
|
||||
|
||||
This sample demonstrates how to create a simple chat application using
|
||||
[AgentChat](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/index.html)
|
||||
and [FastAPI](https://fastapi.tiangolo.com/).
|
||||
|
||||
You will be using the following features of AgentChat:
|
||||
|
||||
1. Agent:
|
||||
- `AssistantAgent`
|
||||
- `UserProxyAgent` with a custom websocket input function
|
||||
2. Team: `RoundRobinGroupChat`
|
||||
3. State persistence: `save_state` and `load_state` methods of both agent and team.
|
||||
|
||||
## Setup
|
||||
|
||||
Install the required packages with OpenAI support:
|
||||
|
||||
```bash
|
||||
pip install -U "autogen-ext[openai]" "fastapi" "uvicorn" "PyYAML"
|
||||
```
|
||||
|
||||
To use models other than OpenAI, see the [Models](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/tutorial/models.html) documentation.
|
||||
|
||||
Create a new file named `model_config.yaml` in the same directory as this README file to configure your model settings.
|
||||
See `model_config_template.yaml` for an example.
|
||||
|
||||
## Chat with a single agent
|
||||
|
||||
To start the FastAPI server for single-agent chat, run:
|
||||
|
||||
```bash
|
||||
python app_agent.py
|
||||
```
|
||||
|
||||
Visit http://localhost:8001 in your browser to start chatting.
|
||||
|
||||
## Chat with a team of agents
|
||||
|
||||
To start the FastAPI server for team chat, run:
|
||||
|
||||
```bash
|
||||
python app_team.py
|
||||
```
|
||||
|
||||
Visit http://localhost:8002 in your browser to start chatting.
|
||||
|
||||
The team also includes a `UserProxyAgent` agent with a custom websocket input function
|
||||
that allows the user to send messages to the team from the browser.
|
||||
|
||||
The team follows a round-robin strategy so each agent will take turns to respond.
|
||||
When it is the user's turn, the input box will be enabled.
|
||||
Once the user sends a message, the input box will be disabled and the agents
|
||||
will take turns to respond.
|
||||
|
||||
## State persistence
|
||||
|
||||
The agents and team use the `load_state` and `save_state` methods to load and save
|
||||
their state from and to files on each turn.
|
||||
For the agent, the state is saved to and loaded from `agent_state.json`.
|
||||
For the team, the state is saved to and loaded from `team_state.json`.
|
||||
You can inspect the state files to see the state of the agents and team
|
||||
once you have chatted with them.
|
||||
|
||||
When the server restarts, the agents and team will load their state from the state files
|
||||
to maintain their state across restarts.
|
||||
|
||||
Additionally, the apps uses separate JSON files,
|
||||
`agent_history.json` and `team_history.json`, to store the conversation history
|
||||
for display in the browser.
|
|
@ -0,0 +1,195 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>AutoGen FastAPI Sample: Agent</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: Arial, sans-serif;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
height: 100vh;
|
||||
background-color: #f0f0f0;
|
||||
}
|
||||
|
||||
#chat-container {
|
||||
width: 90%;
|
||||
max-width: 600px;
|
||||
background-color: #fff;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
#messages {
|
||||
height: 600px;
|
||||
overflow-y: auto;
|
||||
border-bottom: 1px solid #ddd;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
.message {
|
||||
margin: 10px 0;
|
||||
}
|
||||
|
||||
.message.user {
|
||||
text-align: right;
|
||||
}
|
||||
|
||||
.message.assistant {
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
.message.error {
|
||||
color: #721c24;
|
||||
background-color: #f8d7da;
|
||||
border: 1px solid #f5c6cb;
|
||||
padding: 10px;
|
||||
border-radius: 4px;
|
||||
margin: 10px 0;
|
||||
}
|
||||
|
||||
.message.system {
|
||||
color: #0c5460;
|
||||
background-color: #d1ecf1;
|
||||
border: 1px solid #bee5eb;
|
||||
padding: 10px;
|
||||
border-radius: 4px;
|
||||
margin: 10px 0;
|
||||
}
|
||||
|
||||
#input-container input:disabled,
|
||||
#input-container button:disabled {
|
||||
background-color: #e0e0e0;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
#input-container {
|
||||
display: flex;
|
||||
}
|
||||
|
||||
#input-container input {
|
||||
flex: 1;
|
||||
padding: 10px;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
#input-container button {
|
||||
padding: 10px 20px;
|
||||
border: none;
|
||||
background-color: #007bff;
|
||||
color: #fff;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<div id="chat-container">
|
||||
<div id="messages"></div>
|
||||
<div id="input-container">
|
||||
<input type="text" id="message-input" placeholder="Type a message...">
|
||||
<button onclick="sendMessage()">Send</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
document.getElementById('message-input').addEventListener('keydown', function (event) {
|
||||
if (event.key === 'Enter') {
|
||||
sendMessage();
|
||||
}
|
||||
});
|
||||
|
||||
async function sendMessage() {
|
||||
const input = document.getElementById('message-input');
|
||||
const button = document.querySelector('#input-container button');
|
||||
const message = input.value;
|
||||
if (!message) return;
|
||||
|
||||
// Display user message
|
||||
displayMessage(message, 'user');
|
||||
|
||||
// Clear input and disable controls
|
||||
input.value = '';
|
||||
input.disabled = true;
|
||||
button.disabled = true;
|
||||
|
||||
try {
|
||||
const response = await fetch('http://localhost:8001/chat', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
body: JSON.stringify({ content: message, source: 'user' })
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
if (!response.ok) {
|
||||
// Handle error response
|
||||
if (data.detail && data.detail.type === 'error') {
|
||||
displayMessage(data.detail.content, 'error');
|
||||
} else {
|
||||
displayMessage('Error: ' + (data.detail || 'Unknown error'), 'error');
|
||||
}
|
||||
} else {
|
||||
displayMessage(data.content, 'assistant');
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error:', error);
|
||||
displayMessage('Error: Could not reach the server.', 'error');
|
||||
} finally {
|
||||
// Re-enable controls
|
||||
input.disabled = false;
|
||||
button.disabled = false;
|
||||
input.focus();
|
||||
}
|
||||
}
|
||||
|
||||
function displayMessage(content, source) {
|
||||
const messagesContainer = document.getElementById('messages');
|
||||
const messageElement = document.createElement('div');
|
||||
messageElement.className = `message ${source}`;
|
||||
|
||||
const labelElement = document.createElement('span');
|
||||
labelElement.className = 'label';
|
||||
labelElement.textContent = source;
|
||||
|
||||
const contentElement = document.createElement('div');
|
||||
contentElement.className = 'content';
|
||||
contentElement.textContent = content;
|
||||
|
||||
messageElement.appendChild(labelElement);
|
||||
messageElement.appendChild(contentElement);
|
||||
messagesContainer.appendChild(messageElement);
|
||||
messagesContainer.scrollTop = messagesContainer.scrollHeight;
|
||||
}
|
||||
|
||||
async function loadHistory() {
|
||||
try {
|
||||
const response = await fetch('http://localhost:8001/history');
|
||||
if (!response.ok) {
|
||||
throw new Error('Network response was not ok');
|
||||
}
|
||||
const history = await response.json();
|
||||
history.forEach(message => {
|
||||
displayMessage(message.content, message.source);
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('Error loading history:', error);
|
||||
}
|
||||
}
|
||||
|
||||
// Load chat history when the page loads
|
||||
window.onload = loadHistory;
|
||||
</script>
|
||||
</body>
|
||||
|
||||
</html>
|
|
@ -0,0 +1,111 @@
|
|||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import aiofiles
|
||||
import yaml
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.messages import TextMessage
|
||||
from autogen_core import CancellationToken
|
||||
from autogen_core.models import ChatCompletionClient
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # Allows all origins
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"], # Allows all methods
|
||||
allow_headers=["*"], # Allows all headers
|
||||
)
|
||||
|
||||
# Serve static files
|
||||
app.mount("/static", StaticFiles(directory="."), name="static")
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Serve the chat interface HTML file."""
|
||||
return FileResponse("app_agent.html")
|
||||
|
||||
model_config_path = "model_config.yaml"
|
||||
state_path = "agent_state.json"
|
||||
history_path = "agent_history.json"
|
||||
|
||||
|
||||
async def get_agent() -> AssistantAgent:
|
||||
"""Get the assistant agent, load state from file."""
|
||||
# Get model client from config.
|
||||
async with aiofiles.open(model_config_path, "r") as file:
|
||||
model_config = yaml.safe_load(await file.read())
|
||||
model_client = ChatCompletionClient.load_component(model_config)
|
||||
# Create the assistant agent.
|
||||
agent = AssistantAgent(
|
||||
name="assistant",
|
||||
model_client=model_client,
|
||||
system_message="You are a helpful assistant.",
|
||||
)
|
||||
# Load state from file.
|
||||
if not os.path.exists(state_path):
|
||||
return agent # Return agent without loading state.
|
||||
async with aiofiles.open(state_path, "r") as file:
|
||||
state = json.loads(await file.read())
|
||||
await agent.load_state(state)
|
||||
return agent
|
||||
|
||||
|
||||
async def get_history() -> list[dict[str, Any]]:
|
||||
"""Get chat history from file."""
|
||||
if not os.path.exists(history_path):
|
||||
return []
|
||||
async with aiofiles.open(history_path, "r") as file:
|
||||
return json.loads(await file.read())
|
||||
|
||||
|
||||
@app.get("/history")
|
||||
async def history() -> list[dict[str, Any]]:
|
||||
try:
|
||||
return await get_history()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@app.post("/chat", response_model=TextMessage)
|
||||
async def chat(request: TextMessage) -> TextMessage:
|
||||
try:
|
||||
# Get the agent and respond to the message.
|
||||
agent = await get_agent()
|
||||
response = await agent.on_messages(messages=[request], cancellation_token=CancellationToken())
|
||||
|
||||
# Save agent state to file.
|
||||
state = await agent.save_state()
|
||||
async with aiofiles.open(state_path, "w") as file:
|
||||
await file.write(json.dumps(state))
|
||||
|
||||
# Save chat history to file.
|
||||
history = await get_history()
|
||||
history.append(request.model_dump())
|
||||
history.append(response.chat_message.model_dump())
|
||||
async with aiofiles.open(history_path, "w") as file:
|
||||
await file.write(json.dumps(history))
|
||||
|
||||
assert isinstance(response.chat_message, TextMessage)
|
||||
return response.chat_message
|
||||
except Exception as e:
|
||||
error_message = {
|
||||
"type": "error",
|
||||
"content": f"Error: {str(e)}",
|
||||
"source": "system"
|
||||
}
|
||||
raise HTTPException(status_code=500, detail=error_message) from e
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8001)
|
|
@ -0,0 +1,217 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>AutoGen FastAPI Sample: Team</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: Arial, sans-serif;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
height: 100vh;
|
||||
background-color: #f0f0f0;
|
||||
}
|
||||
|
||||
#chat-container {
|
||||
width: 90%;
|
||||
max-width: 600px;
|
||||
background-color: #fff;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
#messages {
|
||||
height: 600px;
|
||||
overflow-y: auto;
|
||||
border-bottom: 1px solid #ddd;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
.message {
|
||||
margin: 10px 0;
|
||||
}
|
||||
|
||||
.message.user {
|
||||
text-align: right;
|
||||
}
|
||||
|
||||
.message.assistant {
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
.label {
|
||||
font-weight: bold;
|
||||
display: block;
|
||||
}
|
||||
|
||||
.content {
|
||||
margin-top: 5px;
|
||||
}
|
||||
|
||||
#input-container {
|
||||
display: flex;
|
||||
}
|
||||
|
||||
#input-container input {
|
||||
flex: 1;
|
||||
padding: 10px;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
#input-container button {
|
||||
padding: 10px 20px;
|
||||
border: none;
|
||||
background-color: #007bff;
|
||||
color: #fff;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
#input-container input:disabled,
|
||||
#input-container button:disabled {
|
||||
background-color: #e0e0e0;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.message.error {
|
||||
color: #721c24;
|
||||
background-color: #f8d7da;
|
||||
border: 1px solid #f5c6cb;
|
||||
padding: 10px;
|
||||
border-radius: 4px;
|
||||
margin: 10px 0;
|
||||
}
|
||||
|
||||
.message.system {
|
||||
color: #0c5460;
|
||||
background-color: #d1ecf1;
|
||||
border: 1px solid #bee5eb;
|
||||
padding: 10px;
|
||||
border-radius: 4px;
|
||||
margin: 10px 0;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<div id="chat-container">
|
||||
<div id="messages"></div>
|
||||
<div id="input-container">
|
||||
<input type="text" id="message-input" placeholder="Type a message...">
|
||||
<button id="send-button" onclick="sendMessage()">Send</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
const ws = new WebSocket('ws://localhost:8002/ws/chat');
|
||||
|
||||
ws.onmessage = function (event) {
|
||||
const message = JSON.parse(event.data);
|
||||
|
||||
if (message.type === 'UserInputRequestedEvent') {
|
||||
// Re-enable input and send button if UserInputRequestedEvent is received
|
||||
enableInput();
|
||||
}
|
||||
else if (message.type === 'error') {
|
||||
// Display error message
|
||||
displayMessage(message.content, 'error');
|
||||
enableInput();
|
||||
}
|
||||
else {
|
||||
// Display regular message
|
||||
displayMessage(message.content, message.source);
|
||||
}
|
||||
};
|
||||
|
||||
ws.onerror = function(error) {
|
||||
displayMessage("WebSocket error occurred. Please refresh the page.", 'error');
|
||||
enableInput();
|
||||
};
|
||||
|
||||
ws.onclose = function() {
|
||||
displayMessage("Connection closed. Please refresh the page.", 'system');
|
||||
disableInput();
|
||||
};
|
||||
|
||||
document.getElementById('message-input').addEventListener('keydown', function (event) {
|
||||
if (event.key === 'Enter' && !event.target.disabled) {
|
||||
sendMessage();
|
||||
}
|
||||
});
|
||||
|
||||
async function sendMessage() {
|
||||
const input = document.getElementById('message-input');
|
||||
const button = document.getElementById('send-button');
|
||||
const message = input.value;
|
||||
if (!message) return;
|
||||
|
||||
// Clear input and disable input and send button
|
||||
input.value = '';
|
||||
disableInput();
|
||||
|
||||
// Send message to WebSocket
|
||||
ws.send(JSON.stringify({ content: message, source: 'user' }));
|
||||
}
|
||||
|
||||
function displayMessage(content, source) {
|
||||
const messagesContainer = document.getElementById('messages');
|
||||
const messageElement = document.createElement('div');
|
||||
messageElement.className = `message ${source}`;
|
||||
|
||||
const labelElement = document.createElement('span');
|
||||
labelElement.className = 'label';
|
||||
labelElement.textContent = source;
|
||||
|
||||
const contentElement = document.createElement('div');
|
||||
contentElement.className = 'content';
|
||||
contentElement.textContent = content;
|
||||
|
||||
messageElement.appendChild(labelElement);
|
||||
messageElement.appendChild(contentElement);
|
||||
messagesContainer.appendChild(messageElement);
|
||||
messagesContainer.scrollTop = messagesContainer.scrollHeight;
|
||||
}
|
||||
|
||||
function disableInput() {
|
||||
const input = document.getElementById('message-input');
|
||||
const button = document.getElementById('send-button');
|
||||
input.disabled = true;
|
||||
button.disabled = true;
|
||||
}
|
||||
|
||||
function enableInput() {
|
||||
const input = document.getElementById('message-input');
|
||||
const button = document.getElementById('send-button');
|
||||
input.disabled = false;
|
||||
button.disabled = false;
|
||||
}
|
||||
|
||||
async function loadHistory() {
|
||||
try {
|
||||
const response = await fetch('http://localhost:8002/history');
|
||||
if (!response.ok) {
|
||||
throw new Error('Network response was not ok');
|
||||
}
|
||||
const history = await response.json();
|
||||
history.forEach(message => {
|
||||
displayMessage(message.content, message.source);
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('Error loading history:', error);
|
||||
}
|
||||
}
|
||||
|
||||
// Load chat history when the page loads
|
||||
window.onload = loadHistory;
|
||||
</script>
|
||||
</body>
|
||||
|
||||
</html>
|
|
@ -0,0 +1,166 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Awaitable, Callable, Optional
|
||||
|
||||
import aiofiles
|
||||
import yaml
|
||||
from autogen_agentchat.agents import AssistantAgent, UserProxyAgent
|
||||
from autogen_agentchat.base import TaskResult
|
||||
from autogen_agentchat.messages import TextMessage, UserInputRequestedEvent
|
||||
from autogen_agentchat.teams import RoundRobinGroupChat
|
||||
from autogen_core import CancellationToken
|
||||
from autogen_core.models import ChatCompletionClient
|
||||
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # Allows all origins
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"], # Allows all methods
|
||||
allow_headers=["*"], # Allows all headers
|
||||
)
|
||||
|
||||
model_config_path = "model_config.yaml"
|
||||
state_path = "team_state.json"
|
||||
history_path = "team_history.json"
|
||||
|
||||
# Serve static files
|
||||
app.mount("/static", StaticFiles(directory="."), name="static")
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Serve the chat interface HTML file."""
|
||||
return FileResponse("app_team.html")
|
||||
|
||||
|
||||
async def get_team(
|
||||
user_input_func: Callable[[str, Optional[CancellationToken]], Awaitable[str]],
|
||||
) -> RoundRobinGroupChat:
|
||||
# Get model client from config.
|
||||
async with aiofiles.open(model_config_path, "r") as file:
|
||||
model_config = yaml.safe_load(await file.read())
|
||||
model_client = ChatCompletionClient.load_component(model_config)
|
||||
# Create the team.
|
||||
agent = AssistantAgent(
|
||||
name="assistant",
|
||||
model_client=model_client,
|
||||
system_message="You are a helpful assistant.",
|
||||
)
|
||||
yoda = AssistantAgent(
|
||||
name="yoda",
|
||||
model_client=model_client,
|
||||
system_message="Repeat the same message in the tone of Yoda.",
|
||||
)
|
||||
user_proxy = UserProxyAgent(
|
||||
name="user",
|
||||
input_func=user_input_func, # Use the user input function.
|
||||
)
|
||||
team = RoundRobinGroupChat(
|
||||
[agent, yoda, user_proxy],
|
||||
)
|
||||
# Load state from file.
|
||||
if not os.path.exists(state_path):
|
||||
return team
|
||||
async with aiofiles.open(state_path, "r") as file:
|
||||
state = json.loads(await file.read())
|
||||
await team.load_state(state)
|
||||
return team
|
||||
|
||||
|
||||
async def get_history() -> list[dict[str, Any]]:
|
||||
"""Get chat history from file."""
|
||||
if not os.path.exists(history_path):
|
||||
return []
|
||||
async with aiofiles.open(history_path, "r") as file:
|
||||
return json.loads(await file.read())
|
||||
|
||||
|
||||
@app.get("/history")
|
||||
async def history() -> list[dict[str, Any]]:
|
||||
try:
|
||||
return await get_history()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@app.websocket("/ws/chat")
|
||||
async def chat(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
|
||||
# User input function used by the team.
|
||||
async def _user_input(prompt: str, cancellation_token: CancellationToken | None) -> str:
|
||||
data = await websocket.receive_json()
|
||||
message = TextMessage.model_validate(data)
|
||||
return message.content
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Get user message.
|
||||
data = await websocket.receive_json()
|
||||
request = TextMessage.model_validate(data)
|
||||
|
||||
try:
|
||||
# Get the team and respond to the message.
|
||||
team = await get_team(_user_input)
|
||||
history = await get_history()
|
||||
stream = team.run_stream(task=request)
|
||||
async for message in stream:
|
||||
if isinstance(message, TaskResult):
|
||||
continue
|
||||
await websocket.send_json(message.model_dump())
|
||||
if not isinstance(message, UserInputRequestedEvent):
|
||||
# Don't save user input events to history.
|
||||
history.append(message.model_dump())
|
||||
|
||||
# Save team state to file.
|
||||
async with aiofiles.open(state_path, "w") as file:
|
||||
state = await team.save_state()
|
||||
await file.write(json.dumps(state))
|
||||
|
||||
# Save chat history to file.
|
||||
async with aiofiles.open(history_path, "w") as file:
|
||||
await file.write(json.dumps(history))
|
||||
|
||||
except Exception as e:
|
||||
# Send error message to client
|
||||
error_message = {
|
||||
"type": "error",
|
||||
"content": f"Error: {str(e)}",
|
||||
"source": "system"
|
||||
}
|
||||
await websocket.send_json(error_message)
|
||||
# Re-enable input after error
|
||||
await websocket.send_json({
|
||||
"type": "UserInputRequestedEvent",
|
||||
"content": "An error occurred. Please try again.",
|
||||
"source": "system"
|
||||
})
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info("Client disconnected")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {str(e)}")
|
||||
try:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"content": f"Unexpected error: {str(e)}",
|
||||
"source": "system"
|
||||
})
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8002)
|
|
@ -0,0 +1,26 @@
|
|||
# Use Open AI with key
|
||||
provider: autogen_ext.models.openai.OpenAIChatCompletionClient
|
||||
config:
|
||||
model: gpt-4o
|
||||
api_key: REPLACE_WITH_YOUR_API_KEY
|
||||
# Use Azure Open AI with key
|
||||
# provider: autogen_ext.models.openai.AzureOpenAIChatCompletionClient
|
||||
# config:
|
||||
# model: gpt-4o
|
||||
# azure_endpoint: https://{your-custom-endpoint}.openai.azure.com/
|
||||
# azure_deployment: {your-azure-deployment}
|
||||
# api_version: {your-api-version}
|
||||
# api_key: REPLACE_WITH_YOUR_API_KEY
|
||||
# Use Azure OpenAI with AD token provider.
|
||||
# provider: autogen_ext.models.openai.AzureOpenAIChatCompletionClient
|
||||
# config:
|
||||
# model: gpt-4o
|
||||
# azure_endpoint: https://{your-custom-endpoint}.openai.azure.com/
|
||||
# azure_deployment: {your-azure-deployment}
|
||||
# api_version: {your-api-version}
|
||||
# azure_ad_token_provider:
|
||||
# provider: autogen_ext.auth.azure.AzureTokenProvider
|
||||
# config:
|
||||
# provider_kind: DefaultAzureCredential
|
||||
# scopes:
|
||||
# - https://cognitiveservices.azure.com/.default
|
Loading…
Reference in New Issue