From 7915a85c60a17c7427dddd6f1d61bd6c4eeb6694 Mon Sep 17 00:00:00 2001 From: Marko Kranjac Date: Thu, 17 Apr 2025 19:25:02 +0200 Subject: [PATCH 1/3] fix: AWS Bedrock documents feat: Bedrock has requirement that each document needs unique name so I added fingerprint based on document content fix: decode document content based on ContentFormat variant --- Cargo.lock | 1 + rig-bedrock/Cargo.toml | 1 + rig-bedrock/src/types/document.rs | 58 +++++++++++++++++++++++++------ 3 files changed, 50 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5714d75..718872e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8633,6 +8633,7 @@ dependencies = [ "reqwest 0.12.15", "rig-core", "rig-derive", + "ring 0.17.14", "schemars", "serde", "serde_json", diff --git a/rig-bedrock/Cargo.toml b/rig-bedrock/Cargo.toml index bb204b5..473e63d 100644 --- a/rig-bedrock/Cargo.toml +++ b/rig-bedrock/Cargo.toml @@ -18,6 +18,7 @@ aws-sdk-bedrockruntime = "1.77.0" aws-smithy-types = "1.3.0" base64 = "0.22.1" async-stream = "0.3.6" +ring = "0.17.14" [dev-dependencies] anyhow = "1.0.75" diff --git a/rig-bedrock/src/types/document.rs b/rig-bedrock/src/types/document.rs index daf0bd3..ab73f26 100644 --- a/rig-bedrock/src/types/document.rs +++ b/rig-bedrock/src/types/document.rs @@ -1,5 +1,4 @@ use aws_sdk_bedrockruntime::types as aws_bedrock; - use rig::{ completion::CompletionError, message::{ContentFormat, Document}, @@ -7,35 +6,55 @@ use rig::{ pub(crate) use crate::types::media_types::RigDocumentMediaType; use base64::{prelude::BASE64_STANDARD, Engine}; +use ring::digest::{Context, Digest, SHA256}; #[derive(Clone)] pub struct RigDocument(pub Document); +impl RigDocument { + pub fn fingerprint(&self) -> String { + let mut context = Context::new(&SHA256); + context.update(self.0.data.as_bytes()); + let digest: Digest = context.finish(); + digest + .as_ref() + .to_vec() + .iter() + .map(|b| format!("{:02x}", b)) + .collect::() + } +} + impl TryFrom for aws_bedrock::DocumentBlock { type Error = CompletionError; fn try_from(value: RigDocument) -> Result { - let maybe_format = value + let document_name = value.fingerprint(); + let document_media_type = value .0 .media_type .map(|doc| RigDocumentMediaType(doc).try_into()); - let format = match maybe_format { + let document_media_type = match document_media_type { Some(Ok(document_format)) => Ok(Some(document_format)), Some(Err(err)) => Err(err), None => Ok(None), }?; - let document_data = BASE64_STANDARD - .decode(value.0.data) - .map_err(|e| CompletionError::ProviderError(e.to_string()))?; + let document_data = match value.0.format { + Some(ContentFormat::Base64) => BASE64_STANDARD + .decode(value.0.data) + .map_err(|e| CompletionError::ProviderError(e.to_string()))?, + _ => value.0.data.as_bytes().to_vec(), + }; + let data = aws_smithy_types::Blob::new(document_data); let document_source = aws_bedrock::DocumentSource::Bytes(data); let result = aws_bedrock::DocumentBlock::builder() .source(document_source) - .name("Document") - .set_format(format) + .name(document_name) + .set_format(document_media_type) .build() .map_err(|e| CompletionError::ProviderError(e.to_string()))?; Ok(result) @@ -82,13 +101,32 @@ mod tests { fn test_document_to_aws_document() { let rig_document = RigDocument(Document { data: "data".into(), - format: Some(ContentFormat::Base64), + format: Some(ContentFormat::String), media_type: Some(DocumentMediaType::PDF), }); let aws_document: Result = rig_document.clone().try_into(); assert_eq!(aws_document.is_ok(), true); let aws_document = aws_document.unwrap(); assert_eq!(aws_document.format, aws_bedrock::DocumentFormat::Pdf); + let document_data = rig_document.0.data.as_bytes().to_vec(); + let aws_document_bytes = aws_document + .source() + .unwrap() + .as_bytes() + .unwrap() + .as_ref() + .to_owned(); + assert_eq!(aws_document_bytes, document_data) + } + + #[test] + fn test_base64_document_to_aws_document() { + let rig_document = RigDocument(Document { + data: "data".into(), + format: Some(ContentFormat::Base64), + media_type: Some(DocumentMediaType::PDF), + }); + let aws_document: aws_bedrock::DocumentBlock = rig_document.clone().try_into().unwrap(); let document_data = BASE64_STANDARD.decode(rig_document.0.data).unwrap(); let aws_document_bytes = aws_document .source() @@ -104,7 +142,7 @@ mod tests { fn test_unsupported_document_to_aws_document() { let rig_document = RigDocument(Document { data: "data".into(), - format: Some(ContentFormat::Base64), + format: Some(ContentFormat::String), media_type: Some(DocumentMediaType::Javascript), }); let aws_document: Result = rig_document.clone().try_into(); From 177859781451769a4f2ba56e66b17ce0b0f35e16 Mon Sep 17 00:00:00 2001 From: Marko Kranjac Date: Fri, 18 Apr 2025 01:20:18 +0200 Subject: [PATCH 2/3] feat: AWS Bedrock document changes feat: Bedrock has requirement that each document needs unique name so I added fingerprint based on document content bugfix: decode document content based on ContentFormat variant feat: use normalized_documents instead of prompt_with_context --- Cargo.lock | 2 +- rig-bedrock/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 718872e..7396992 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8622,7 +8622,7 @@ checksum = "57397d16646700483b67d2dd6511d79318f9d057fdbd21a4066aeac8b41d310a" [[package]] name = "rig-bedrock" -version = "0.1.1" +version = "0.1.2" dependencies = [ "anyhow", "async-stream", diff --git a/rig-bedrock/Cargo.toml b/rig-bedrock/Cargo.toml index 473e63d..99af363 100644 --- a/rig-bedrock/Cargo.toml +++ b/rig-bedrock/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rig-bedrock" -version = "0.1.1" +version = "0.1.2" edition = "2021" license = "MIT" readme = "README.md" From 5f042fc84b81b4311e9abf88e2cd85d43494f5ec Mon Sep 17 00:00:00 2001 From: Marko Kranjac Date: Fri, 18 Apr 2025 13:54:52 +0200 Subject: [PATCH 3/3] feat: added multi-turn example for AWS Bedrock --- rig-bedrock/examples/agent_with_bedrock.rs | 3 +- rig-bedrock/examples/common/adder_tool.rs | 59 ++++++++++ .../examples/common/address_book_tool.rs | 104 ++++++++++++++++++ rig-bedrock/examples/common/mod.rs | 62 +---------- rig-bedrock/examples/multi_turn_bedrock.rs | 34 ++++++ .../streaming_with_bedrock_and_tools.rs | 3 +- 6 files changed, 203 insertions(+), 62 deletions(-) create mode 100644 rig-bedrock/examples/common/adder_tool.rs create mode 100644 rig-bedrock/examples/common/address_book_tool.rs create mode 100644 rig-bedrock/examples/multi_turn_bedrock.rs diff --git a/rig-bedrock/examples/agent_with_bedrock.rs b/rig-bedrock/examples/agent_with_bedrock.rs index b6eb9ba..2a45103 100644 --- a/rig-bedrock/examples/agent_with_bedrock.rs +++ b/rig-bedrock/examples/agent_with_bedrock.rs @@ -6,6 +6,7 @@ use rig_bedrock::{ use tracing::info; mod common; +use common::adder_tool::Adder; /// Runs 4 agents based on AWS Bedrock (derived from the agent_with_grok example) #[tokio::main] @@ -59,7 +60,7 @@ async fn tools() -> Result<(), anyhow::Error> { .await .preamble("You must only do math by using a tool.") .max_tokens(1024) - .tool(common::Adder) + .tool(Adder) .build(); info!( diff --git a/rig-bedrock/examples/common/adder_tool.rs b/rig-bedrock/examples/common/adder_tool.rs new file mode 100644 index 0000000..0a104d3 --- /dev/null +++ b/rig-bedrock/examples/common/adder_tool.rs @@ -0,0 +1,59 @@ +use rig::{completion::ToolDefinition, tool::Tool}; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use std::{ + error::Error, + fmt::{Display, Formatter}, +}; + +#[derive(Deserialize)] +pub struct OperationArgs { + x: i32, + y: i32, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct MathError {} + +impl Display for MathError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "Math error") + } +} + +impl Error for MathError {} + +#[derive(Deserialize, Serialize)] +pub struct Adder; +impl Tool for Adder { + const NAME: &'static str = "add"; + + type Error = MathError; + type Args = OperationArgs; + type Output = i32; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + ToolDefinition { + name: "add".to_string(), + description: "Add x and y together".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "The first number to add" + }, + "y": { + "type": "number", + "description": "The second number to add" + } + } + }), + } + } + + async fn call(&self, args: Self::Args) -> Result { + let result = args.x + args.y; + Ok(result) + } +} diff --git a/rig-bedrock/examples/common/address_book_tool.rs b/rig-bedrock/examples/common/address_book_tool.rs new file mode 100644 index 0000000..36dbd4d --- /dev/null +++ b/rig-bedrock/examples/common/address_book_tool.rs @@ -0,0 +1,104 @@ +use rig::{completion::ToolDefinition, tool::Tool}; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use std::collections::HashMap; +use std::{ + error::Error, + fmt::{Display, Formatter}, +}; + +#[derive(Debug)] +pub struct AddressBookError(String); + +impl Display for AddressBookError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "Address Book error {}", self.0) + } +} + +impl Error for AddressBookError {} + +#[derive(Serialize, Clone)] +pub struct AddressBook { + street_name: String, + city: String, + state: String, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +#[serde(untagged)] +pub enum AddressBookResult { + Found(AddressBook), + NotFound(String), +} + +#[derive(Deserialize)] +pub struct AddressBookArgs { + email: String, +} + +#[derive(Deserialize, Serialize)] +pub struct AddressBookTool; +impl Tool for AddressBookTool { + const NAME: &'static str = "address_book"; + + type Error = AddressBookError; + type Args = AddressBookArgs; + type Output = AddressBookResult; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + ToolDefinition { + name: "address_book".to_string(), + description: "get address by email".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "email": { + "type": "string", + "description": "email address" + }, + } + }), + } + } + + async fn call(&self, args: Self::Args) -> Result { + let mut address_book: HashMap = HashMap::new(); + address_book.extend(vec![ + ( + "john.doe@example.com".to_string(), + AddressBook { + street_name: "123 Elm St".to_string(), + city: "Springfield".to_string(), + state: "IL".to_string(), + }, + ), + ( + "jane.smith@example.com".to_string(), + AddressBook { + street_name: "456 Oak St".to_string(), + city: "Metropolis".to_string(), + state: "NY".to_string(), + }, + ), + ( + "alice.johnson@example.com".to_string(), + AddressBook { + street_name: "789 Pine St".to_string(), + city: "Gotham".to_string(), + state: "NJ".to_string(), + }, + ), + ]); + + if args.email.starts_with("malice") { + return Err(AddressBookError("Corrupted database".into())); + } + + match address_book.get(&args.email) { + Some(address) => Ok(AddressBookResult::Found(address.clone())), + None => Ok(AddressBookResult::NotFound("Address not found".into())), + } + } +} diff --git a/rig-bedrock/examples/common/mod.rs b/rig-bedrock/examples/common/mod.rs index 6cc9468..391bfad 100644 --- a/rig-bedrock/examples/common/mod.rs +++ b/rig-bedrock/examples/common/mod.rs @@ -1,60 +1,2 @@ -use std::{ - error::Error, - fmt::{Display, Formatter}, -}; - -use rig::{completion::ToolDefinition, tool::Tool}; -use serde::{Deserialize, Serialize}; -use serde_json::json; - -#[derive(Deserialize)] -pub struct OperationArgs { - x: i32, - y: i32, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct MathError {} - -impl Display for MathError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "Math error") - } -} - -impl Error for MathError {} - -#[derive(Deserialize, Serialize)] -pub struct Adder; -impl Tool for Adder { - const NAME: &'static str = "add"; - - type Error = MathError; - type Args = OperationArgs; - type Output = i32; - - async fn definition(&self, _prompt: String) -> ToolDefinition { - ToolDefinition { - name: "add".to_string(), - description: "Add x and y together".to_string(), - parameters: json!({ - "type": "object", - "properties": { - "x": { - "type": "number", - "description": "The first number to add" - }, - "y": { - "type": "number", - "description": "The second number to add" - } - } - }), - } - } - - async fn call(&self, args: Self::Args) -> Result { - let result = args.x + args.y; - Ok(result) - } -} +pub mod adder_tool; +pub mod address_book_tool; diff --git a/rig-bedrock/examples/multi_turn_bedrock.rs b/rig-bedrock/examples/multi_turn_bedrock.rs new file mode 100644 index 0000000..c6c41d8 --- /dev/null +++ b/rig-bedrock/examples/multi_turn_bedrock.rs @@ -0,0 +1,34 @@ +use rig::completion::Prompt; +use rig_bedrock::{client::ClientBuilder, completion::AMAZON_NOVA_LITE}; +mod common; +use common::address_book_tool::AddressBookTool; + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + tracing_subscriber::fmt().init(); + // Create agent with a single context prompt and two tools + let agent = ClientBuilder::new() + .build() + .await + .agent(AMAZON_NOVA_LITE) + .preamble("You have access to user address tool. Never return part") + .max_tokens(1024) + .tool(AddressBookTool) + .build(); + + let result = agent + .prompt("Can you find address for this email: jane.smith@example.com") + .multi_turn(20) + .await?; + + println!("\n{}", result); + + let result = agent + .prompt("Can you find address for this email: does_not_exists@example.com") + .multi_turn(20) + .await?; + + println!("\n{}", result); + + Ok(()) +} diff --git a/rig-bedrock/examples/streaming_with_bedrock_and_tools.rs b/rig-bedrock/examples/streaming_with_bedrock_and_tools.rs index e18865f..2c2b739 100644 --- a/rig-bedrock/examples/streaming_with_bedrock_and_tools.rs +++ b/rig-bedrock/examples/streaming_with_bedrock_and_tools.rs @@ -1,6 +1,7 @@ use rig::streaming::{stream_to_stdout, StreamingPrompt}; use rig_bedrock::{client::ClientBuilder, completion::AMAZON_NOVA_LITE}; mod common; +use common::adder_tool::Adder; #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -17,7 +18,7 @@ async fn main() -> Result<(), anyhow::Error> { like 20 words", ) .max_tokens(1024) - .tool(common::Adder) + .tool(Adder) .build(); println!("Calculate 2 + 5");