mirror of https://github.com/0xplaygrounds/rig
feat: added multi-turn example for AWS Bedrock
This commit is contained in:
parent
1778597814
commit
5f042fc84b
|
@ -6,6 +6,7 @@ use rig_bedrock::{
|
|||
use tracing::info;
|
||||
|
||||
mod common;
|
||||
use common::adder_tool::Adder;
|
||||
|
||||
/// Runs 4 agents based on AWS Bedrock (derived from the agent_with_grok example)
|
||||
#[tokio::main]
|
||||
|
@ -59,7 +60,7 @@ async fn tools() -> Result<(), anyhow::Error> {
|
|||
.await
|
||||
.preamble("You must only do math by using a tool.")
|
||||
.max_tokens(1024)
|
||||
.tool(common::Adder)
|
||||
.tool(Adder)
|
||||
.build();
|
||||
|
||||
info!(
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
use rig::{completion::ToolDefinition, tool::Tool};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use std::{
|
||||
error::Error,
|
||||
fmt::{Display, Formatter},
|
||||
};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct OperationArgs {
|
||||
x: i32,
|
||||
y: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct MathError {}
|
||||
|
||||
impl Display for MathError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Math error")
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for MathError {}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
pub struct Adder;
|
||||
impl Tool for Adder {
|
||||
const NAME: &'static str = "add";
|
||||
|
||||
type Error = MathError;
|
||||
type Args = OperationArgs;
|
||||
type Output = i32;
|
||||
|
||||
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||
ToolDefinition {
|
||||
name: "add".to_string(),
|
||||
description: "Add x and y together".to_string(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {
|
||||
"type": "number",
|
||||
"description": "The first number to add"
|
||||
},
|
||||
"y": {
|
||||
"type": "number",
|
||||
"description": "The second number to add"
|
||||
}
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||
let result = args.x + args.y;
|
||||
Ok(result)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,104 @@
|
|||
use rig::{completion::ToolDefinition, tool::Tool};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::{
|
||||
error::Error,
|
||||
fmt::{Display, Formatter},
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AddressBookError(String);
|
||||
|
||||
impl Display for AddressBookError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Address Book error {}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for AddressBookError {}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct AddressBook {
|
||||
street_name: String,
|
||||
city: String,
|
||||
state: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[serde(untagged)]
|
||||
pub enum AddressBookResult {
|
||||
Found(AddressBook),
|
||||
NotFound(String),
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct AddressBookArgs {
|
||||
email: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
pub struct AddressBookTool;
|
||||
impl Tool for AddressBookTool {
|
||||
const NAME: &'static str = "address_book";
|
||||
|
||||
type Error = AddressBookError;
|
||||
type Args = AddressBookArgs;
|
||||
type Output = AddressBookResult;
|
||||
|
||||
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||
ToolDefinition {
|
||||
name: "address_book".to_string(),
|
||||
description: "get address by email".to_string(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"email": {
|
||||
"type": "string",
|
||||
"description": "email address"
|
||||
},
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||
let mut address_book: HashMap<String, AddressBook> = HashMap::new();
|
||||
address_book.extend(vec![
|
||||
(
|
||||
"john.doe@example.com".to_string(),
|
||||
AddressBook {
|
||||
street_name: "123 Elm St".to_string(),
|
||||
city: "Springfield".to_string(),
|
||||
state: "IL".to_string(),
|
||||
},
|
||||
),
|
||||
(
|
||||
"jane.smith@example.com".to_string(),
|
||||
AddressBook {
|
||||
street_name: "456 Oak St".to_string(),
|
||||
city: "Metropolis".to_string(),
|
||||
state: "NY".to_string(),
|
||||
},
|
||||
),
|
||||
(
|
||||
"alice.johnson@example.com".to_string(),
|
||||
AddressBook {
|
||||
street_name: "789 Pine St".to_string(),
|
||||
city: "Gotham".to_string(),
|
||||
state: "NJ".to_string(),
|
||||
},
|
||||
),
|
||||
]);
|
||||
|
||||
if args.email.starts_with("malice") {
|
||||
return Err(AddressBookError("Corrupted database".into()));
|
||||
}
|
||||
|
||||
match address_book.get(&args.email) {
|
||||
Some(address) => Ok(AddressBookResult::Found(address.clone())),
|
||||
None => Ok(AddressBookResult::NotFound("Address not found".into())),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,60 +1,2 @@
|
|||
use std::{
|
||||
error::Error,
|
||||
fmt::{Display, Formatter},
|
||||
};
|
||||
|
||||
use rig::{completion::ToolDefinition, tool::Tool};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct OperationArgs {
|
||||
x: i32,
|
||||
y: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct MathError {}
|
||||
|
||||
impl Display for MathError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Math error")
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for MathError {}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
pub struct Adder;
|
||||
impl Tool for Adder {
|
||||
const NAME: &'static str = "add";
|
||||
|
||||
type Error = MathError;
|
||||
type Args = OperationArgs;
|
||||
type Output = i32;
|
||||
|
||||
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||
ToolDefinition {
|
||||
name: "add".to_string(),
|
||||
description: "Add x and y together".to_string(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {
|
||||
"type": "number",
|
||||
"description": "The first number to add"
|
||||
},
|
||||
"y": {
|
||||
"type": "number",
|
||||
"description": "The second number to add"
|
||||
}
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||
let result = args.x + args.y;
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
pub mod adder_tool;
|
||||
pub mod address_book_tool;
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
use rig::completion::Prompt;
|
||||
use rig_bedrock::{client::ClientBuilder, completion::AMAZON_NOVA_LITE};
|
||||
mod common;
|
||||
use common::address_book_tool::AddressBookTool;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
tracing_subscriber::fmt().init();
|
||||
// Create agent with a single context prompt and two tools
|
||||
let agent = ClientBuilder::new()
|
||||
.build()
|
||||
.await
|
||||
.agent(AMAZON_NOVA_LITE)
|
||||
.preamble("You have access to user address tool. Never return <thinking> part")
|
||||
.max_tokens(1024)
|
||||
.tool(AddressBookTool)
|
||||
.build();
|
||||
|
||||
let result = agent
|
||||
.prompt("Can you find address for this email: jane.smith@example.com")
|
||||
.multi_turn(20)
|
||||
.await?;
|
||||
|
||||
println!("\n{}", result);
|
||||
|
||||
let result = agent
|
||||
.prompt("Can you find address for this email: does_not_exists@example.com")
|
||||
.multi_turn(20)
|
||||
.await?;
|
||||
|
||||
println!("\n{}", result);
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
use rig::streaming::{stream_to_stdout, StreamingPrompt};
|
||||
use rig_bedrock::{client::ClientBuilder, completion::AMAZON_NOVA_LITE};
|
||||
mod common;
|
||||
use common::adder_tool::Adder;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
|
@ -17,7 +18,7 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||
like 20 words",
|
||||
)
|
||||
.max_tokens(1024)
|
||||
.tool(common::Adder)
|
||||
.tool(Adder)
|
||||
.build();
|
||||
|
||||
println!("Calculate 2 + 5");
|
||||
|
|
Loading…
Reference in New Issue