mirror of https://github.com/0xplaygrounds/rig
Merge 5f042fc84b
into 1dd15d8f8c
This commit is contained in:
commit
aeb19a18bd
|
@ -8622,7 +8622,7 @@ checksum = "57397d16646700483b67d2dd6511d79318f9d057fdbd21a4066aeac8b41d310a"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rig-bedrock"
|
name = "rig-bedrock"
|
||||||
version = "0.1.1"
|
version = "0.1.2"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"async-stream",
|
"async-stream",
|
||||||
|
@ -8633,6 +8633,7 @@ dependencies = [
|
||||||
"reqwest 0.12.15",
|
"reqwest 0.12.15",
|
||||||
"rig-core",
|
"rig-core",
|
||||||
"rig-derive",
|
"rig-derive",
|
||||||
|
"ring 0.17.14",
|
||||||
"schemars",
|
"schemars",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[package]
|
[package]
|
||||||
name = "rig-bedrock"
|
name = "rig-bedrock"
|
||||||
version = "0.1.1"
|
version = "0.1.2"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
@ -18,6 +18,7 @@ aws-sdk-bedrockruntime = "1.77.0"
|
||||||
aws-smithy-types = "1.3.0"
|
aws-smithy-types = "1.3.0"
|
||||||
base64 = "0.22.1"
|
base64 = "0.22.1"
|
||||||
async-stream = "0.3.6"
|
async-stream = "0.3.6"
|
||||||
|
ring = "0.17.14"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
anyhow = "1.0.75"
|
anyhow = "1.0.75"
|
||||||
|
|
|
@ -6,6 +6,7 @@ use rig_bedrock::{
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
mod common;
|
mod common;
|
||||||
|
use common::adder_tool::Adder;
|
||||||
|
|
||||||
/// Runs 4 agents based on AWS Bedrock (derived from the agent_with_grok example)
|
/// Runs 4 agents based on AWS Bedrock (derived from the agent_with_grok example)
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
|
@ -59,7 +60,7 @@ async fn tools() -> Result<(), anyhow::Error> {
|
||||||
.await
|
.await
|
||||||
.preamble("You must only do math by using a tool.")
|
.preamble("You must only do math by using a tool.")
|
||||||
.max_tokens(1024)
|
.max_tokens(1024)
|
||||||
.tool(common::Adder)
|
.tool(Adder)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
|
|
|
@ -0,0 +1,59 @@
|
||||||
|
use rig::{completion::ToolDefinition, tool::Tool};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::json;
|
||||||
|
use std::{
|
||||||
|
error::Error,
|
||||||
|
fmt::{Display, Formatter},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub struct OperationArgs {
|
||||||
|
x: i32,
|
||||||
|
y: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub struct MathError {}
|
||||||
|
|
||||||
|
impl Display for MathError {
|
||||||
|
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "Math error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Error for MathError {}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Serialize)]
|
||||||
|
pub struct Adder;
|
||||||
|
impl Tool for Adder {
|
||||||
|
const NAME: &'static str = "add";
|
||||||
|
|
||||||
|
type Error = MathError;
|
||||||
|
type Args = OperationArgs;
|
||||||
|
type Output = i32;
|
||||||
|
|
||||||
|
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||||
|
ToolDefinition {
|
||||||
|
name: "add".to_string(),
|
||||||
|
description: "Add x and y together".to_string(),
|
||||||
|
parameters: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"x": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The first number to add"
|
||||||
|
},
|
||||||
|
"y": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The second number to add"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||||
|
let result = args.x + args.y;
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,104 @@
|
||||||
|
use rig::{completion::ToolDefinition, tool::Tool};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::json;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::{
|
||||||
|
error::Error,
|
||||||
|
fmt::{Display, Formatter},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct AddressBookError(String);
|
||||||
|
|
||||||
|
impl Display for AddressBookError {
|
||||||
|
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "Address Book error {}", self.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Error for AddressBookError {}
|
||||||
|
|
||||||
|
#[derive(Serialize, Clone)]
|
||||||
|
pub struct AddressBook {
|
||||||
|
street_name: String,
|
||||||
|
city: String,
|
||||||
|
state: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum AddressBookResult {
|
||||||
|
Found(AddressBook),
|
||||||
|
NotFound(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub struct AddressBookArgs {
|
||||||
|
email: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Serialize)]
|
||||||
|
pub struct AddressBookTool;
|
||||||
|
impl Tool for AddressBookTool {
|
||||||
|
const NAME: &'static str = "address_book";
|
||||||
|
|
||||||
|
type Error = AddressBookError;
|
||||||
|
type Args = AddressBookArgs;
|
||||||
|
type Output = AddressBookResult;
|
||||||
|
|
||||||
|
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||||
|
ToolDefinition {
|
||||||
|
name: "address_book".to_string(),
|
||||||
|
description: "get address by email".to_string(),
|
||||||
|
parameters: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"email": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "email address"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||||
|
let mut address_book: HashMap<String, AddressBook> = HashMap::new();
|
||||||
|
address_book.extend(vec![
|
||||||
|
(
|
||||||
|
"john.doe@example.com".to_string(),
|
||||||
|
AddressBook {
|
||||||
|
street_name: "123 Elm St".to_string(),
|
||||||
|
city: "Springfield".to_string(),
|
||||||
|
state: "IL".to_string(),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"jane.smith@example.com".to_string(),
|
||||||
|
AddressBook {
|
||||||
|
street_name: "456 Oak St".to_string(),
|
||||||
|
city: "Metropolis".to_string(),
|
||||||
|
state: "NY".to_string(),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"alice.johnson@example.com".to_string(),
|
||||||
|
AddressBook {
|
||||||
|
street_name: "789 Pine St".to_string(),
|
||||||
|
city: "Gotham".to_string(),
|
||||||
|
state: "NJ".to_string(),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]);
|
||||||
|
|
||||||
|
if args.email.starts_with("malice") {
|
||||||
|
return Err(AddressBookError("Corrupted database".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
match address_book.get(&args.email) {
|
||||||
|
Some(address) => Ok(AddressBookResult::Found(address.clone())),
|
||||||
|
None => Ok(AddressBookResult::NotFound("Address not found".into())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,60 +1,2 @@
|
||||||
use std::{
|
pub mod adder_tool;
|
||||||
error::Error,
|
pub mod address_book_tool;
|
||||||
fmt::{Display, Formatter},
|
|
||||||
};
|
|
||||||
|
|
||||||
use rig::{completion::ToolDefinition, tool::Tool};
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use serde_json::json;
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
pub struct OperationArgs {
|
|
||||||
x: i32,
|
|
||||||
y: i32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
||||||
pub struct MathError {}
|
|
||||||
|
|
||||||
impl Display for MathError {
|
|
||||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
|
||||||
write!(f, "Math error")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Error for MathError {}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Serialize)]
|
|
||||||
pub struct Adder;
|
|
||||||
impl Tool for Adder {
|
|
||||||
const NAME: &'static str = "add";
|
|
||||||
|
|
||||||
type Error = MathError;
|
|
||||||
type Args = OperationArgs;
|
|
||||||
type Output = i32;
|
|
||||||
|
|
||||||
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
|
||||||
ToolDefinition {
|
|
||||||
name: "add".to_string(),
|
|
||||||
description: "Add x and y together".to_string(),
|
|
||||||
parameters: json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"x": {
|
|
||||||
"type": "number",
|
|
||||||
"description": "The first number to add"
|
|
||||||
},
|
|
||||||
"y": {
|
|
||||||
"type": "number",
|
|
||||||
"description": "The second number to add"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
|
||||||
let result = args.x + args.y;
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -0,0 +1,34 @@
|
||||||
|
use rig::completion::Prompt;
|
||||||
|
use rig_bedrock::{client::ClientBuilder, completion::AMAZON_NOVA_LITE};
|
||||||
|
mod common;
|
||||||
|
use common::address_book_tool::AddressBookTool;
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), anyhow::Error> {
|
||||||
|
tracing_subscriber::fmt().init();
|
||||||
|
// Create agent with a single context prompt and two tools
|
||||||
|
let agent = ClientBuilder::new()
|
||||||
|
.build()
|
||||||
|
.await
|
||||||
|
.agent(AMAZON_NOVA_LITE)
|
||||||
|
.preamble("You have access to user address tool. Never return <thinking> part")
|
||||||
|
.max_tokens(1024)
|
||||||
|
.tool(AddressBookTool)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
let result = agent
|
||||||
|
.prompt("Can you find address for this email: jane.smith@example.com")
|
||||||
|
.multi_turn(20)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
println!("\n{}", result);
|
||||||
|
|
||||||
|
let result = agent
|
||||||
|
.prompt("Can you find address for this email: does_not_exists@example.com")
|
||||||
|
.multi_turn(20)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
println!("\n{}", result);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
|
@ -1,6 +1,7 @@
|
||||||
use rig::streaming::{stream_to_stdout, StreamingPrompt};
|
use rig::streaming::{stream_to_stdout, StreamingPrompt};
|
||||||
use rig_bedrock::{client::ClientBuilder, completion::AMAZON_NOVA_LITE};
|
use rig_bedrock::{client::ClientBuilder, completion::AMAZON_NOVA_LITE};
|
||||||
mod common;
|
mod common;
|
||||||
|
use common::adder_tool::Adder;
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<(), anyhow::Error> {
|
async fn main() -> Result<(), anyhow::Error> {
|
||||||
|
@ -17,7 +18,7 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
like 20 words",
|
like 20 words",
|
||||||
)
|
)
|
||||||
.max_tokens(1024)
|
.max_tokens(1024)
|
||||||
.tool(common::Adder)
|
.tool(Adder)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
println!("Calculate 2 + 5");
|
println!("Calculate 2 + 5");
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
use aws_sdk_bedrockruntime::types as aws_bedrock;
|
use aws_sdk_bedrockruntime::types as aws_bedrock;
|
||||||
|
|
||||||
use rig::{
|
use rig::{
|
||||||
completion::CompletionError,
|
completion::CompletionError,
|
||||||
message::{ContentFormat, Document},
|
message::{ContentFormat, Document},
|
||||||
|
@ -7,35 +6,55 @@ use rig::{
|
||||||
|
|
||||||
pub(crate) use crate::types::media_types::RigDocumentMediaType;
|
pub(crate) use crate::types::media_types::RigDocumentMediaType;
|
||||||
use base64::{prelude::BASE64_STANDARD, Engine};
|
use base64::{prelude::BASE64_STANDARD, Engine};
|
||||||
|
use ring::digest::{Context, Digest, SHA256};
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct RigDocument(pub Document);
|
pub struct RigDocument(pub Document);
|
||||||
|
|
||||||
|
impl RigDocument {
|
||||||
|
pub fn fingerprint(&self) -> String {
|
||||||
|
let mut context = Context::new(&SHA256);
|
||||||
|
context.update(self.0.data.as_bytes());
|
||||||
|
let digest: Digest = context.finish();
|
||||||
|
digest
|
||||||
|
.as_ref()
|
||||||
|
.to_vec()
|
||||||
|
.iter()
|
||||||
|
.map(|b| format!("{:02x}", b))
|
||||||
|
.collect::<String>()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl TryFrom<RigDocument> for aws_bedrock::DocumentBlock {
|
impl TryFrom<RigDocument> for aws_bedrock::DocumentBlock {
|
||||||
type Error = CompletionError;
|
type Error = CompletionError;
|
||||||
|
|
||||||
fn try_from(value: RigDocument) -> Result<Self, Self::Error> {
|
fn try_from(value: RigDocument) -> Result<Self, Self::Error> {
|
||||||
let maybe_format = value
|
let document_name = value.fingerprint();
|
||||||
|
let document_media_type = value
|
||||||
.0
|
.0
|
||||||
.media_type
|
.media_type
|
||||||
.map(|doc| RigDocumentMediaType(doc).try_into());
|
.map(|doc| RigDocumentMediaType(doc).try_into());
|
||||||
|
|
||||||
let format = match maybe_format {
|
let document_media_type = match document_media_type {
|
||||||
Some(Ok(document_format)) => Ok(Some(document_format)),
|
Some(Ok(document_format)) => Ok(Some(document_format)),
|
||||||
Some(Err(err)) => Err(err),
|
Some(Err(err)) => Err(err),
|
||||||
None => Ok(None),
|
None => Ok(None),
|
||||||
}?;
|
}?;
|
||||||
|
|
||||||
let document_data = BASE64_STANDARD
|
let document_data = match value.0.format {
|
||||||
.decode(value.0.data)
|
Some(ContentFormat::Base64) => BASE64_STANDARD
|
||||||
.map_err(|e| CompletionError::ProviderError(e.to_string()))?;
|
.decode(value.0.data)
|
||||||
|
.map_err(|e| CompletionError::ProviderError(e.to_string()))?,
|
||||||
|
_ => value.0.data.as_bytes().to_vec(),
|
||||||
|
};
|
||||||
|
|
||||||
let data = aws_smithy_types::Blob::new(document_data);
|
let data = aws_smithy_types::Blob::new(document_data);
|
||||||
let document_source = aws_bedrock::DocumentSource::Bytes(data);
|
let document_source = aws_bedrock::DocumentSource::Bytes(data);
|
||||||
|
|
||||||
let result = aws_bedrock::DocumentBlock::builder()
|
let result = aws_bedrock::DocumentBlock::builder()
|
||||||
.source(document_source)
|
.source(document_source)
|
||||||
.name("Document")
|
.name(document_name)
|
||||||
.set_format(format)
|
.set_format(document_media_type)
|
||||||
.build()
|
.build()
|
||||||
.map_err(|e| CompletionError::ProviderError(e.to_string()))?;
|
.map_err(|e| CompletionError::ProviderError(e.to_string()))?;
|
||||||
Ok(result)
|
Ok(result)
|
||||||
|
@ -82,13 +101,32 @@ mod tests {
|
||||||
fn test_document_to_aws_document() {
|
fn test_document_to_aws_document() {
|
||||||
let rig_document = RigDocument(Document {
|
let rig_document = RigDocument(Document {
|
||||||
data: "data".into(),
|
data: "data".into(),
|
||||||
format: Some(ContentFormat::Base64),
|
format: Some(ContentFormat::String),
|
||||||
media_type: Some(DocumentMediaType::PDF),
|
media_type: Some(DocumentMediaType::PDF),
|
||||||
});
|
});
|
||||||
let aws_document: Result<aws_bedrock::DocumentBlock, _> = rig_document.clone().try_into();
|
let aws_document: Result<aws_bedrock::DocumentBlock, _> = rig_document.clone().try_into();
|
||||||
assert_eq!(aws_document.is_ok(), true);
|
assert_eq!(aws_document.is_ok(), true);
|
||||||
let aws_document = aws_document.unwrap();
|
let aws_document = aws_document.unwrap();
|
||||||
assert_eq!(aws_document.format, aws_bedrock::DocumentFormat::Pdf);
|
assert_eq!(aws_document.format, aws_bedrock::DocumentFormat::Pdf);
|
||||||
|
let document_data = rig_document.0.data.as_bytes().to_vec();
|
||||||
|
let aws_document_bytes = aws_document
|
||||||
|
.source()
|
||||||
|
.unwrap()
|
||||||
|
.as_bytes()
|
||||||
|
.unwrap()
|
||||||
|
.as_ref()
|
||||||
|
.to_owned();
|
||||||
|
assert_eq!(aws_document_bytes, document_data)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_base64_document_to_aws_document() {
|
||||||
|
let rig_document = RigDocument(Document {
|
||||||
|
data: "data".into(),
|
||||||
|
format: Some(ContentFormat::Base64),
|
||||||
|
media_type: Some(DocumentMediaType::PDF),
|
||||||
|
});
|
||||||
|
let aws_document: aws_bedrock::DocumentBlock = rig_document.clone().try_into().unwrap();
|
||||||
let document_data = BASE64_STANDARD.decode(rig_document.0.data).unwrap();
|
let document_data = BASE64_STANDARD.decode(rig_document.0.data).unwrap();
|
||||||
let aws_document_bytes = aws_document
|
let aws_document_bytes = aws_document
|
||||||
.source()
|
.source()
|
||||||
|
@ -104,7 +142,7 @@ mod tests {
|
||||||
fn test_unsupported_document_to_aws_document() {
|
fn test_unsupported_document_to_aws_document() {
|
||||||
let rig_document = RigDocument(Document {
|
let rig_document = RigDocument(Document {
|
||||||
data: "data".into(),
|
data: "data".into(),
|
||||||
format: Some(ContentFormat::Base64),
|
format: Some(ContentFormat::String),
|
||||||
media_type: Some(DocumentMediaType::Javascript),
|
media_type: Some(DocumentMediaType::Javascript),
|
||||||
});
|
});
|
||||||
let aws_document: Result<aws_bedrock::DocumentBlock, _> = rig_document.clone().try_into();
|
let aws_document: Result<aws_bedrock::DocumentBlock, _> = rig_document.clone().try_into();
|
||||||
|
|
Loading…
Reference in New Issue