// Mirror of https://github.com/0xplaygrounds/rig (Rust, 285 lines, 8.9 KiB).
use rig::{
|
|
agent::Agent,
|
|
completion::{self, Completion, PromptError, ToolDefinition},
|
|
message::{AssistantContent, Message, ToolCall, ToolFunction, ToolResultContent, UserContent},
|
|
providers::anthropic,
|
|
tool::Tool,
|
|
OneOrMany,
|
|
};
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::json;
|
|
|
|
struct MultiTurnAgent<M: rig::completion::CompletionModel> {
|
|
agent: Agent<M>,
|
|
chat_history: Vec<completion::Message>,
|
|
}
|
|
|
|
impl<M: rig::completion::CompletionModel> MultiTurnAgent<M> {
|
|
async fn multi_turn_prompt(
|
|
&mut self,
|
|
prompt: impl Into<Message> + Send,
|
|
) -> Result<String, PromptError> {
|
|
let mut current_prompt: Message = prompt.into();
|
|
loop {
|
|
println!("Current Prompt: {:?}\n", current_prompt);
|
|
let resp = self
|
|
.agent
|
|
.completion(current_prompt.clone(), self.chat_history.clone())
|
|
.await?
|
|
.send()
|
|
.await?;
|
|
|
|
let mut final_text = None;
|
|
|
|
for content in resp.choice.into_iter() {
|
|
match content {
|
|
AssistantContent::Text(text) => {
|
|
println!("Intermediate Response: {:?}\n", text.text);
|
|
final_text = Some(text.text.clone());
|
|
self.chat_history.push(current_prompt.clone());
|
|
let response_message = Message::Assistant {
|
|
content: OneOrMany::one(AssistantContent::text(&text.text)),
|
|
};
|
|
self.chat_history.push(response_message);
|
|
}
|
|
AssistantContent::ToolCall(content) => {
|
|
self.chat_history.push(current_prompt.clone());
|
|
let tool_call_msg = AssistantContent::ToolCall(content.clone());
|
|
println!("Tool Call Msg: {:?}\n", tool_call_msg);
|
|
|
|
self.chat_history.push(Message::Assistant {
|
|
content: OneOrMany::one(tool_call_msg),
|
|
});
|
|
|
|
let ToolCall {
|
|
id,
|
|
function: ToolFunction { name, arguments },
|
|
} = content;
|
|
|
|
let tool_result =
|
|
self.agent.tools.call(&name, arguments.to_string()).await?;
|
|
|
|
current_prompt = Message::User {
|
|
content: OneOrMany::one(UserContent::tool_result(
|
|
id,
|
|
OneOrMany::one(ToolResultContent::text(tool_result)),
|
|
)),
|
|
};
|
|
|
|
final_text = None;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
if let Some(text) = final_text {
|
|
return Ok(text);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() -> anyhow::Result<()> {
|
|
// tracing_subscriber::registry()
|
|
// .with(
|
|
// tracing_subscriber::EnvFilter::try_from_default_env()
|
|
// .unwrap_or_else(|_| "stdout=info".into()),
|
|
// )
|
|
// .with(tracing_subscriber::fmt::layer())
|
|
// .init();
|
|
|
|
// Create OpenAI client
|
|
let openai_client = anthropic::Client::from_env();
|
|
|
|
// Create RAG agent with a single context prompt and a dynamic tool source
|
|
let calculator_rag = openai_client
|
|
.agent(anthropic::CLAUDE_3_5_SONNET)
|
|
.preamble(
|
|
"You are an assistant here to help the user select which tool is most appropriate to perform arithmetic operations.
|
|
Follow these instructions closely.
|
|
1. Consider the user's request carefully and identify the core elements of the request.
|
|
2. Select which tool among those made available to you is appropriate given the context.
|
|
3. This is very important: never perform the operation yourself.
|
|
"
|
|
)
|
|
.tool(Add)
|
|
.tool(Subtract)
|
|
.tool(Multiply)
|
|
.tool(Divide)
|
|
.build();
|
|
|
|
let mut agent = MultiTurnAgent {
|
|
agent: calculator_rag,
|
|
chat_history: Vec::new(),
|
|
};
|
|
|
|
// Prompt the agent and print the response
|
|
let result = agent
|
|
.multi_turn_prompt("Calculate 5 - 2 = ?. Describe the result to me.")
|
|
.await?;
|
|
|
|
println!("\n\nOpenAI Calculator Agent: {}", result);
|
|
|
|
// Prompt the agent again and print the response
|
|
let result = agent
|
|
.multi_turn_prompt("Calculate (3 + 5) / 9 = ?. Describe the result to me.")
|
|
.await?;
|
|
|
|
println!("\n\nOpenAI Calculator Agent: {}", result);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct OperationArgs {
|
|
x: i32,
|
|
y: i32,
|
|
}
|
|
|
|
#[derive(Debug, thiserror::Error)]
|
|
#[error("Math error")]
|
|
struct MathError;
|
|
|
|
#[derive(Deserialize, Serialize)]
|
|
struct Add;
|
|
impl Tool for Add {
|
|
const NAME: &'static str = "add";
|
|
|
|
type Error = MathError;
|
|
type Args = OperationArgs;
|
|
type Output = i32;
|
|
|
|
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
|
serde_json::from_value(json!({
|
|
"name": "add",
|
|
"description": "Add x and y together",
|
|
"parameters": {
|
|
"type": "object",
|
|
"properties": {
|
|
"x": {
|
|
"type": "number",
|
|
"description": "The first number to add"
|
|
},
|
|
"y": {
|
|
"type": "number",
|
|
"description": "The second number to add"
|
|
}
|
|
}
|
|
}
|
|
}))
|
|
.expect("Tool Definition")
|
|
}
|
|
|
|
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
|
let result = args.x + args.y;
|
|
Ok(result)
|
|
}
|
|
}
|
|
|
|
#[derive(Deserialize, Serialize)]
|
|
struct Subtract;
|
|
impl Tool for Subtract {
|
|
const NAME: &'static str = "subtract";
|
|
|
|
type Error = MathError;
|
|
type Args = OperationArgs;
|
|
type Output = i32;
|
|
|
|
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
|
serde_json::from_value(json!({
|
|
"name": "subtract",
|
|
"description": "Subtract y from x (i.e.: x - y)",
|
|
"parameters": {
|
|
"type": "object",
|
|
"properties": {
|
|
"x": {
|
|
"type": "number",
|
|
"description": "The number to subtract from"
|
|
},
|
|
"y": {
|
|
"type": "number",
|
|
"description": "The number to subtract"
|
|
}
|
|
}
|
|
}
|
|
}))
|
|
.expect("Tool Definition")
|
|
}
|
|
|
|
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
|
let result = args.x - args.y;
|
|
Ok(result)
|
|
}
|
|
}
|
|
|
|
struct Multiply;
|
|
impl Tool for Multiply {
|
|
const NAME: &'static str = "multiply";
|
|
|
|
type Error = MathError;
|
|
type Args = OperationArgs;
|
|
type Output = i32;
|
|
|
|
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
|
serde_json::from_value(json!({
|
|
"name": "multiply",
|
|
"description": "Compute the product of x and y (i.e.: x * y)",
|
|
"parameters": {
|
|
"type": "object",
|
|
"properties": {
|
|
"x": {
|
|
"type": "number",
|
|
"description": "The first factor in the product"
|
|
},
|
|
"y": {
|
|
"type": "number",
|
|
"description": "The second factor in the product"
|
|
}
|
|
}
|
|
}
|
|
}))
|
|
.expect("Tool Definition")
|
|
}
|
|
|
|
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
|
let result = args.x * args.y;
|
|
Ok(result)
|
|
}
|
|
}
|
|
|
|
struct Divide;
|
|
impl Tool for Divide {
|
|
const NAME: &'static str = "divide";
|
|
|
|
type Error = MathError;
|
|
type Args = OperationArgs;
|
|
type Output = i32;
|
|
|
|
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
|
serde_json::from_value(json!({
|
|
"name": "divide",
|
|
"description": "Compute the Quotient of x and y (i.e.: x / y). Useful for ratios.",
|
|
"parameters": {
|
|
"type": "object",
|
|
"properties": {
|
|
"x": {
|
|
"type": "number",
|
|
"description": "The Dividend of the division. The number being divided"
|
|
},
|
|
"y": {
|
|
"type": "number",
|
|
"description": "The Divisor of the division. The number by which the dividend is being divided"
|
|
}
|
|
}
|
|
}
|
|
}))
|
|
.expect("Tool Definition")
|
|
}
|
|
|
|
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
|
let result = args.x / args.y;
|
|
Ok(result)
|
|
}
|
|
}
|