Merge branch 'main' of https://github.com/0xPlaygrounds/rig into fix/multiple-tool-calling

This commit is contained in:
0xMochan 2025-04-07 16:41:14 -07:00
commit d618d1a435
38 changed files with 2541 additions and 15 deletions

96
Cargo.lock generated
View File

@ -1134,6 +1134,30 @@ dependencies = [
"uuid 1.16.0",
]
[[package]]
name = "aws-sdk-bedrockruntime"
version = "1.77.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4198493316dab97e1fed7716f3823462b73a34c518f4ee7b9799921645e232e5"
dependencies = [
"aws-credential-types",
"aws-runtime",
"aws-smithy-async",
"aws-smithy-eventstream",
"aws-smithy-http",
"aws-smithy-json",
"aws-smithy-runtime",
"aws-smithy-runtime-api",
"aws-smithy-types",
"aws-types",
"bytes",
"fastrand",
"http 0.2.12",
"once_cell",
"regex-lite",
"tracing",
]
[[package]]
name = "aws-sdk-dynamodb"
version = "1.69.0"
@ -1258,12 +1282,24 @@ dependencies = [
"tokio",
]
[[package]]
name = "aws-smithy-eventstream"
version = "0.60.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c45d3dddac16c5c59d553ece225a88870cf81b7b813c9cc17b78cf4685eac7a"
dependencies = [
"aws-smithy-types",
"bytes",
"crc32fast",
]
[[package]]
name = "aws-smithy-http"
version = "0.62.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c5949124d11e538ca21142d1fba61ab0a2a2c1bc3ed323cdb3e4b878bfb83166"
dependencies = [
"aws-smithy-eventstream",
"aws-smithy-runtime-api",
"aws-smithy-types",
"bytes",
@ -8584,6 +8620,27 @@ version = "0.8.50"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57397d16646700483b67d2dd6511d79318f9d057fdbd21a4066aeac8b41d310a"
[[package]]
name = "rig-bedrock"
version = "0.1.0"
dependencies = [
"anyhow",
"async-stream",
"aws-config",
"aws-sdk-bedrockruntime",
"aws-smithy-types",
"base64 0.22.1",
"reqwest 0.12.15",
"rig-core 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)",
"rig-derive",
"schemars",
"serde",
"serde_json",
"tokio",
"tracing",
"tracing-subscriber",
]
[[package]]
name = "rig-core"
version = "0.11.0"
@ -8617,6 +8674,27 @@ dependencies = [
"worker",
]
[[package]]
name = "rig-core"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ff893305131b471009ab11df388612beb603ed94bb12412c256fe197b7591aa6"
dependencies = [
"async-stream",
"base64 0.22.1",
"bytes",
"futures",
"glob",
"mime_guess",
"ordered-float",
"reqwest 0.12.15",
"schemars",
"serde",
"serde_json",
"thiserror 1.0.69",
"tracing",
]
[[package]]
name = "rig-derive"
version = "0.1.0"
@ -8634,7 +8712,7 @@ dependencies = [
"anyhow",
"ethers",
"reqwest 0.12.15",
"rig-core",
"rig-core 0.11.0",
"schemars",
"serde",
"serde_json",
@ -8649,7 +8727,7 @@ version = "0.1.4"
dependencies = [
"anyhow",
"fastembed",
"rig-core",
"rig-core 0.11.0",
"schemars",
"serde",
"serde_json",
@ -8667,7 +8745,7 @@ dependencies = [
"futures",
"httpmock",
"lancedb",
"rig-core",
"rig-core 0.11.0",
"serde",
"serde_json",
"tokio",
@ -8681,7 +8759,7 @@ dependencies = [
"futures",
"httpmock",
"mongodb",
"rig-core",
"rig-core 0.11.0",
"serde",
"serde_json",
"testcontainers",
@ -8698,7 +8776,7 @@ dependencies = [
"futures",
"httpmock",
"neo4rs",
"rig-core",
"rig-core 0.11.0",
"serde",
"serde_json",
"term_size",
@ -8718,7 +8796,7 @@ dependencies = [
"httpmock",
"log",
"pgvector",
"rig-core",
"rig-core 0.11.0",
"serde",
"serde_json",
"sqlx",
@ -8737,7 +8815,7 @@ dependencies = [
"anyhow",
"httpmock",
"qdrant-client",
"rig-core",
"rig-core 0.11.0",
"serde",
"serde_json",
"testcontainers",
@ -8752,7 +8830,7 @@ dependencies = [
"anyhow",
"chrono",
"httpmock",
"rig-core",
"rig-core 0.11.0",
"rusqlite",
"serde",
"serde_json",
@ -8769,7 +8847,7 @@ name = "rig-surrealdb"
version = "0.1.3"
dependencies = [
"anyhow",
"rig-core",
"rig-core 0.11.0",
"serde",
"serde_json",
"surrealdb",

View File

@ -9,6 +9,8 @@ members = [
"rig-qdrant",
"rig-core/rig-core-derive",
"rig-sqlite",
"rig-eternalai", "rig-fastembed",
"rig-surrealdb",
"rig-eternalai",
"rig-fastembed",
"rig-bedrock",
]

26
rig-bedrock/Cargo.toml Normal file
View File

@ -0,0 +1,26 @@
[package]
name = "rig-bedrock"
version = "0.1.0"
edition = "2021"
license = "MIT"
readme = "README.md"
description = "AWS Bedrock model provider for Rig integration."
[dependencies]
rig-core = { version = "0.11.0", features = ["image"] }
rig-derive = { path = "../rig-core/rig-core-derive", version = "0.1.0" }
serde = { version = "1.0.193", features = ["derive"] }
serde_json = "1.0.108"
schemars = "0.8.16"
tracing = "0.1.40"
aws-config = { version = "1.6.0", features = ["behavior-version-latest"] }
aws-sdk-bedrockruntime = "1.77.0"
aws-smithy-types = "1.3.0"
base64 = "0.22.1"
async-stream = "0.3.6"
[dev-dependencies]
anyhow = "1.0.75"
tokio = { version = "1.34.0", features = ["full"] }
tracing-subscriber = "0.3.18"
reqwest = { version = "0.12.12", features = ["json", "stream"] }

23
rig-bedrock/README.md Normal file
View File

@ -0,0 +1,23 @@
## Rig-Bedrock
This companion crate integrates AWS Bedrock as model provider with Rig.
## Usage
Add the companion crate to your `Cargo.toml`, along with the rig-core crate:
```toml
[dependencies]
rig-bedrock = "0.1.0"
rig-core = "0.9.1"
```
You can also run `cargo add rig-bedrock rig-core` to add the most recent versions of the dependencies to your project.
See the [`/examples`](./examples) folder for usage examples.
Make sure to have AWS credentials env vars loaded before starting client such as:
```shell
export AWS_DEFAULT_REGION=us-east-1
export AWS_SECRET_ACCESS_KEY=.......
export AWS_ACCESS_KEY_ID=......
```

View File

@ -0,0 +1,121 @@
use rig::{agent::AgentBuilder, completion::Prompt, loaders::FileLoader};
use rig_bedrock::{
client::{Client, ClientBuilder},
completion::AMAZON_NOVA_LITE,
};
use tracing::info;
mod common;
/// Runs 4 agents based on AWS Bedrock (derived from the agent_with_grok example)
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
.with_target(false)
.init();
info!("Running basic agent");
basic().await?;
info!("\nRunning agent with tools");
tools().await?;
info!("\nRunning agent with loaders");
loaders().await?;
info!("\nRunning agent with context");
context().await?;
info!("\n\nAll agents ran successfully");
Ok(())
}
async fn client() -> Client {
ClientBuilder::new().build().await
}
async fn partial_agent() -> AgentBuilder<rig_bedrock::completion::CompletionModel> {
let client = client().await;
client.agent(AMAZON_NOVA_LITE)
}
/// Create an AWS Bedrock agent with a system prompt
async fn basic() -> Result<(), anyhow::Error> {
let agent = partial_agent()
.await
.preamble("Answer with json format only")
.build();
let response = agent.prompt("Describe solar system").await?;
info!("{}", response);
Ok(())
}
/// Create an AWS Bedrock with tools
async fn tools() -> Result<(), anyhow::Error> {
let calculator_agent = partial_agent()
.await
.preamble("You must only do math by using a tool.")
.max_tokens(1024)
.tool(common::Adder)
.build();
info!(
"Calculator Agent: add 400 and 20\nResult: {}",
calculator_agent.prompt("add 400 and 20").await?
);
Ok(())
}
async fn context() -> Result<(), anyhow::Error> {
let model = client().await.completion_model(AMAZON_NOVA_LITE);
// Create an agent with multiple context documents
let agent = AgentBuilder::new(model)
.preamble("Answer the question")
.context("Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
.context("Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
.context("Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
.build();
// Prompt the agent and print the response
let response = agent.prompt("What does \"glarb-glarb\" mean?").await?;
info!("What does \"glarb-glarb\" mean?\n{}", response);
Ok(())
}
/// Based upon the `loaders` example
///
/// This example loads in all the rust examples from the rig-core crate and uses them as\\
/// context for the agent
async fn loaders() -> Result<(), anyhow::Error> {
let model = client().await.completion_model(AMAZON_NOVA_LITE);
// Load in all the rust examples
let examples = FileLoader::with_glob("rig-core/examples/*.rs")?
.read_with_path()
.ignore_errors()
.into_iter();
// Create an agent with multiple context documents
let agent = examples
.fold(AgentBuilder::new(model), |builder, (path, content)| {
builder.context(format!("Rust Example {:?}:\n{}", path, content).as_str())
})
.preamble("Answer the question")
.build();
// Prompt the agent and print the response
let response = agent
.prompt("Which rust example is best suited for the operation 1 + 2")
.await?;
info!("{}", response);
Ok(())
}

View File

@ -0,0 +1,60 @@
use std::{
error::Error,
fmt::{Display, Formatter},
};
use rig::{completion::ToolDefinition, tool::Tool};
use serde::{Deserialize, Serialize};
use serde_json::json;
#[derive(Deserialize)]
pub struct OperationArgs {
x: i32,
y: i32,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MathError {}
impl Display for MathError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "Math error")
}
}
impl Error for MathError {}
#[derive(Deserialize, Serialize)]
pub struct Adder;
impl Tool for Adder {
const NAME: &'static str = "add";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
ToolDefinition {
name: "add".to_string(),
description: "Add x and y together".to_string(),
parameters: json!({
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The first number to add"
},
"y": {
"type": "number",
"description": "The second number to add"
}
}
}),
}
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let result = args.x + args.y;
Ok(result)
}
}

View File

@ -0,0 +1,50 @@
use reqwest::Client;
use rig::{
completion::{message::Document, Prompt},
message::{ContentFormat, DocumentMediaType},
};
use base64::{prelude::BASE64_STANDARD, Engine};
use rig_bedrock::{client::ClientBuilder, completion::AMAZON_NOVA_LITE};
use tracing::info;
const DOCUMENT_URL: &str =
"https://www.inf.ed.ac.uk/teaching/courses/ai2/module4/small_slides/small-agents.pdf";
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
.without_time()
.with_level(false)
.with_target(false)
.init();
let client = ClientBuilder::new().build().await;
let agent = client
.agent(AMAZON_NOVA_LITE)
.preamble("Describe this document but respond with json format only")
.temperature(0.5)
.build();
let reqwest_client = Client::new();
let response = reqwest_client.get(DOCUMENT_URL).send().await?;
info!("Status: {}", response.status().as_str());
info!("Content Type: {:?}", response.headers().get("Content-Type"));
let document_bytes = response.bytes().await?;
let bytes_base64 = BASE64_STANDARD.encode(document_bytes);
let document = Document {
data: bytes_base64,
format: Some(ContentFormat::Base64),
media_type: Some(DocumentMediaType::PDF),
};
let response = agent.prompt(document).await?;
info!("{}", response);
Ok(())
}

View File

@ -0,0 +1,34 @@
use rig::Embed;
use rig_bedrock::{client::ClientBuilder, embedding::AMAZON_TITAN_EMBED_TEXT_V2_0};
use tracing::info;
#[derive(rig_derive::Embed, Debug)]
struct Greetings {
#[embed]
message: String,
}
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
.with_target(false)
.init();
let client = ClientBuilder::new().build().await;
let embeddings = client
.embeddings(AMAZON_TITAN_EMBED_TEXT_V2_0, 256)
.document(Greetings {
message: "aa".to_string(),
})?
.document(Greetings {
message: "bb".to_string(),
})?
.build()
.await?;
info!("{:?}", embeddings);
Ok(())
}

View File

@ -0,0 +1,32 @@
use rig_bedrock::{client::ClientBuilder, completion::AMAZON_NOVA_LITE};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use tracing::info;
#[derive(Debug, Deserialize, JsonSchema, Serialize)]
struct Person {
pub first_name: Option<String>,
pub last_name: Option<String>,
pub job: Option<String>,
}
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
.with_target(false)
.init();
let client = ClientBuilder::new().build().await;
let data_extractor = client.extractor::<Person>(AMAZON_NOVA_LITE).build();
let person = data_extractor
.extract("Hello my name is John Doe! I am a software engineer.")
.await?;
info!(
"AWS Bedrock: {}",
serde_json::to_string_pretty(&person).unwrap()
);
Ok(())
}

View File

@ -0,0 +1,25 @@
use rig::image_generation::ImageGenerationModel;
use rig_bedrock::client::ClientBuilder;
use rig_bedrock::image::AMAZON_NOVA_CANVAS;
use std::fs::File;
use std::io::Write;
use std::path::Path;
const DEFAULT_PATH: &str = "./output.png";
#[tokio::main]
async fn main() {
let client = ClientBuilder::new().build().await;
let image_generation_model = client.image_generation_model(AMAZON_NOVA_CANVAS);
let response = image_generation_model
.image_generation_request()
.prompt("A castle sitting upon a large mountain, overlooking the water.")
.width(512)
.height(512)
.send()
.await;
// save image
let mut file = File::create_new(Path::new(&DEFAULT_PATH)).expect("Failed to create file");
let _ = file.write(&response.unwrap().image);
}

View File

@ -0,0 +1,46 @@
use reqwest::Client;
use rig::{
completion::{message::Image, Prompt},
message::{ContentFormat, ImageMediaType},
};
use base64::{prelude::BASE64_STANDARD, Engine};
use rig_bedrock::{client::ClientBuilder, completion::AMAZON_NOVA_LITE};
use tracing::info;
const IMAGE_URL: &str = "https://playgrounds.network/assets/PG-Logo.png";
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
.with_target(false)
.init();
let client = ClientBuilder::new().build().await;
let agent = client
.agent(AMAZON_NOVA_LITE)
.preamble("You are an image describer.")
.temperature(0.5)
.build();
// Grab image and convert to base64
let reqwest_client = Client::new();
let image_bytes = reqwest_client.get(IMAGE_URL).send().await?.bytes().await?;
let image_base64 = BASE64_STANDARD.encode(image_bytes);
// Compose `Image` for prompt
let image = Image {
data: image_base64,
media_type: Some(ImageMediaType::PNG),
format: Some(ContentFormat::Base64),
..Default::default()
};
// Prompt the agent and print the response
let response = agent.prompt(image).await?;
info!("{}", response);
Ok(())
}

View File

@ -0,0 +1,85 @@
use std::vec;
use rig::{
completion::Prompt, embeddings::EmbeddingsBuilder,
vector_store::in_memory_store::InMemoryVectorStore, Embed,
};
use rig_bedrock::{
client::ClientBuilder, completion::AMAZON_NOVA_LITE, embedding::AMAZON_TITAN_EMBED_TEXT_V2_0,
};
use serde::Serialize;
use tracing::info;
// Data to be RAG-ed.
// A vector search needs to be performed on the `definitions` field, so we derive the `Embed` trait for `WordDefinition`
// and tag that field with `#[embed]`.
#[derive(rig_derive::Embed, Serialize, Clone, Debug, Eq, PartialEq, Default)]
struct WordDefinition {
id: String,
word: String,
#[embed]
definitions: Vec<String>,
}
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
.with_target(false)
.init();
let client = ClientBuilder::new().build().await;
let embedding_model = client.embedding_model(AMAZON_TITAN_EMBED_TEXT_V2_0, 256);
// Generate embeddings for the definitions of all the documents using the specified embedding model.
let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
.documents(vec![
WordDefinition {
id: "doc0".to_string(),
word: "flurbo".to_string(),
definitions: vec![
"1. *flurbo* (name): A flurbo is a green alien that lives on cold planets.".to_string(),
"2. *flurbo* (name): A fictional digital currency that originated in the animated series Rick and Morty.".to_string()
]
},
WordDefinition {
id: "doc1".to_string(),
word: "glarb-glarb".to_string(),
definitions: vec![
"1. *glarb-glarb* (noun): A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
"2. *glarb-glarb* (noun): A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
]
},
WordDefinition {
id: "doc2".to_string(),
word: "linglingdong".to_string(),
definitions: vec![
"1. *linglingdong* (noun): A term used by inhabitants of the far side of the moon to describe humans.".to_string(),
"2. *linglingdong* (noun): A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string()
]
},
])?
.build()
.await?;
// Create vector store with the embeddings
let vector_store = InMemoryVectorStore::from_documents(embeddings);
// Create vector store index
let index = vector_store.index(embedding_model);
let rag_agent = client.agent(AMAZON_NOVA_LITE)
.preamble("
You are a dictionary assistant here to assist the user in understanding the meaning of words.
You will find additional non-standard word definitions that could be useful below.
")
.dynamic_context(1, index)
.build();
// Prompt the agent and print the response
let response = rag_agent.prompt("What does \"glarb-glarb\" mean?").await?;
info!("{}", response);
Ok(())
}

View File

@ -0,0 +1,23 @@
use rig::streaming::{stream_to_stdout, StreamingPrompt};
use rig_bedrock::{client::ClientBuilder, completion::AMAZON_NOVA_LITE};
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Create streaming agent with a single context prompt
let agent = ClientBuilder::new()
.build()
.await
.agent(AMAZON_NOVA_LITE)
.preamble("Be precise and concise.")
.temperature(0.5)
.build();
// Stream the response and print chunks as they arrive
let mut stream = agent
.stream_prompt("When and where and what type is the next solar eclipse?")
.await?;
stream_to_stdout(agent, &mut stream).await?;
Ok(())
}

View File

@ -0,0 +1,27 @@
use rig::streaming::{stream_to_stdout, StreamingPrompt};
use rig_bedrock::{client::ClientBuilder, completion::AMAZON_NOVA_LITE};
mod common;
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
tracing_subscriber::fmt().init();
// Create agent with a single context prompt and two tools
let agent = ClientBuilder::new()
.build()
.await
.agent(AMAZON_NOVA_LITE)
.preamble(
"You are a calculator here to help the user perform arithmetic
operations. Use the tools provided to answer the user's question.
make your answer long, so we can test the streaming functionality,
like 20 words",
)
.max_tokens(1024)
.tool(common::Adder)
.build();
println!("Calculate 2 + 5");
let mut stream = agent.stream_prompt("Calculate 2 + 5").await?;
stream_to_stdout(agent, &mut stream).await?;
Ok(())
}

86
rig-bedrock/src/client.rs Normal file
View File

@ -0,0 +1,86 @@
use crate::image::ImageGenerationModel;
use crate::{completion::CompletionModel, embedding::EmbeddingModel};
use aws_config::{BehaviorVersion, Region};
use rig::{agent::AgentBuilder, embeddings, extractor::ExtractorBuilder, Embed};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
pub const DEFAULT_AWS_REGION: &str = "us-east-1";
#[derive(Clone)]
pub struct ClientBuilder<'a> {
region: &'a str,
}
/// Create a new Bedrock client using the builder <br>
impl<'a> ClientBuilder<'a> {
pub fn new() -> Self {
Self {
region: DEFAULT_AWS_REGION,
}
}
/// Make sure to verify model and region [compatibility]
///
/// [compatibility]: https://docs.aws.amazon.com/bedrock/latest/userguide/models-regions.html
pub fn region(mut self, region: &'a str) -> Self {
self.region = region;
self
}
/// Make sure you have permissions to access [Amazon Bedrock foundation model]
///
/// [ Amazon Bedrock foundation model]: <https://docs.aws.amazon.com/bedrock/latest/userguide/model-access-modify.html>
pub async fn build(self) -> Client {
let sdk_config = aws_config::defaults(BehaviorVersion::latest())
.region(Region::new(String::from(self.region)))
.load()
.await;
let client = aws_sdk_bedrockruntime::Client::new(&sdk_config);
Client { aws_client: client }
}
}
impl Default for ClientBuilder<'_> {
fn default() -> Self {
Self::new()
}
}
#[derive(Clone)]
pub struct Client {
pub(crate) aws_client: aws_sdk_bedrockruntime::Client,
}
impl Client {
pub fn completion_model(&self, model: &str) -> CompletionModel {
CompletionModel::new(self.clone(), model)
}
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
AgentBuilder::new(self.completion_model(model))
}
pub fn embedding_model(&self, model: &str, ndims: usize) -> EmbeddingModel {
EmbeddingModel::new(self.clone(), model, Some(ndims))
}
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
&self,
model: &str,
) -> ExtractorBuilder<T, CompletionModel> {
ExtractorBuilder::new(self.completion_model(model))
}
pub fn embeddings<D: Embed>(
&self,
model: &str,
ndims: usize,
) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D> {
embeddings::EmbeddingsBuilder::new(self.embedding_model(model, ndims))
}
pub fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
ImageGenerationModel::new(self.clone(), model)
}
}

View File

@ -0,0 +1,138 @@
//! All supported models <https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html>
use crate::{
client::Client,
types::{
assistant_content::AwsConverseOutput, completion_request::AwsCompletionRequest,
errors::AwsSdkConverseError,
},
};
use rig::completion::{self, CompletionError};
/// `amazon.nova-canvas-v1:0`
pub const AMAZON_NOVA_CANVAS: &str = "amazon.nova-canvas-v1:0";
/// `amazon.nova-lite-v1:0`
pub const AMAZON_NOVA_LITE: &str = "amazon.nova-lite-v1:0";
/// `amazon.nova-micro-v1:0`
pub const AMAZON_NOVA_MICRO: &str = "amazon.nova-micro-v1:0";
/// `amazon.nova-pro-v1:0`
pub const AMAZON_NOVA_PRO: &str = "amazon.nova-pro-v1:0";
/// `amazon.rerank-v1:0`
pub const AMAZON_RERANK_1_0: &str = "amazon.rerank-v1:0";
/// `amazon.titan-text-express-v1`
pub const AMAZON_TITAN_TEXT_EXPRESS_V1: &str = "amazon.titan-text-express-v1";
/// `amazon.titan-text-lite-v1`
pub const AMAZON_TITAN_TEXT_LITE_V1: &str = "amazon.titan-text-lite-v1";
/// `amazon.titan-text-premier-v1:0`
pub const AMAZON_TITAN_TEXT_PREMIER_V1_0: &str = "amazon.titan-text-premier-v1:0";
/// `anthropic.claude-3-haiku-20240307-v1:0`
pub const ANTHROPIC_CLAUDE_3_HAIKU: &str = "anthropic.claude-3-haiku-20240307-v1:0";
/// `anthropic.claude-3-opus-20240229-v1:0`
pub const ANTHROPIC_CLAUDE_3_OPUS: &str = "anthropic.claude-3-opus-20240229-v1:0";
/// `anthropic.claude-3-sonnet-20240229-v1:0`
pub const ANTHROPIC_CLAUDE_3_SONNET: &str = "anthropic.claude-3-sonnet-20240229-v1:0";
/// `anthropic.claude-3-5-haiku-20241022-v1:0`
pub const ANTHROPIC_CLAUDE_3_5_HAIKU: &str = "anthropic.claude-3-5-haiku-20241022-v1:0";
/// `anthropic.claude-3-5-sonnet-20241022-v2:0`
pub const ANTHROPIC_CLAUDE_3_5_SONNET_V2: &str = "anthropic.claude-3-5-sonnet-20241022-v2:0";
/// `anthropic.claude-3-5-sonnet-20240620-v1:0`
pub const ANTHROPIC_CLAUDE_3_5_SONNET: &str = "anthropic.claude-3-5-sonnet-20240620-v1:0";
/// `anthropic.claude-3-7-sonnet-20250219-v1:0`
pub const ANTHROPIC_CLAUDE_3_7_SONNET: &str = "anthropic.claude-3-7-sonnet-20250219-v1:0";
/// `cohere.command-light-text-v14`
pub const COHERE_COMMAND_LIGHT_TEXT: &str = "cohere.command-light-text-v14";
/// `cohere.command-r-plus-v1:0`
pub const COHERE_COMMAND_R_PLUS: &str = "cohere.command-r-plus-v1:0";
/// `cohere.command-r-v1:0`
pub const COHERE_COMMAND_R: &str = "cohere.command-r-v1:0";
/// `cohere.command-text-v14`
pub const COHERE_COMMAND: &str = "cohere.command-text-v14";
/// `cohere.rerank-v3-5:0`
pub const COHERE_RERANK_V3_5: &str = "cohere.rerank-v3-5:0";
/// `luma.ray-v2:0`
pub const LUMA_RAY_V2_0: &str = "luma.ray-v2:0";
/// `meta.llama3-8b-instruct-v1:0`
pub const LLAMA_3_8B_INSTRUCT: &str = "meta.llama3-8b-instruct-v1:0";
/// `meta.llama3-70b-instruct-v1:0`
pub const LLAMA_3_70B_INSTRUCT: &str = "meta.llama3-70b-instruct-v1:0";
/// `meta.llama3-1-8b-instruct-v1:0`
pub const LLAMA_3_1_8B_INSTRUCT: &str = "meta.llama3-1-8b-instruct-v1:0";
/// `meta.llama3-1-70b-instruct-v1:0`
pub const LLAMA_3_1_70B_INSTRUCT: &str = "meta.llama3-1-70b-instruct-v1:0";
/// `meta.llama3-1-405b-instruct-v1:0`
pub const LLAMA_3_1_405B_INSTRUCT: &str = "meta.llama3-1-405b-instruct-v1:0";
/// `meta.llama3-2-1b-instruct-v1:0`
pub const LLAMA_3_2_1B_INSTRUCT: &str = "meta.llama3-2-1b-instruct-v1:0";
/// `meta.llama3-2-3b-instruct-v1:0`
pub const LLAMA_3_2_3B_INSTRUCT: &str = "meta.llama3-2-3b-instruct-v1:0";
/// `meta.llama3-2-11b-instruct-v1:0`
pub const LLAMA_3_2_11B_INSTRUCT: &str = "meta.llama3-2-11b-instruct-v1:0";
/// `meta.llama3-2-90b-instruct-v1:0`
pub const LLAMA_3_2_90B_INSTRUCT: &str = "meta.llama3-2-90b-instruct-v1:0";
/// `meta.llama3-3-70b-instruct-v1:0`
pub const LLAMA_3_2_70B_INSTRUCT: &str = "meta.llama3-3-70b-instruct-v1:0";
/// `mistral.mistral-7b-instruct-v0:2`
pub const MISTRAL_7B_INSTRUCT: &str = "mistral.mistral-7b-instruct-v0:2";
/// `mistral.mistral-large-2402-v1:0`
pub const MISTRAL_LARGE_24_02: &str = "mistral.mistral-large-2402-v1:0";
/// `mistral.mistral-large-2407-v1:0`
pub const MISTRAL_LARGE_24_07: &str = "mistral.mistral-large-2407-v1:0";
/// `mistral.mistral-small-2402-v1:0`
pub const MISTRAL_SMALL_24_02: &str = "mistral.mistral-small-2402-v1:0";
/// `mistral.mixtral-8x7b-instruct-v0:1`
pub const MISTRAL_MIXTRAL_8X7B_INSTRUCT_V0: &str = "mistral.mixtral-8x7b-instruct-v0:1";
/// `stability.sd3-5-large-v1:0`
pub const STABILITY_SD3_5_LARGE: &str = "stability.sd3-5-large-v1:0";
/// `ai21.jamba-1-5-large-v1:0`
pub const JAMBA_1_5_LARGE: &str = "ai21.jamba-1-5-large-v1:0";
/// `ai21.jamba-1-5-mini-v1:0`
pub const JAMBA_1_5_MINI: &str = "ai21.jamba-1-5-mini-v1:0";
#[derive(Clone)]
pub struct CompletionModel {
pub(crate) client: Client,
pub model: String,
}
impl CompletionModel {
pub fn new(client: Client, model: &str) -> Self {
Self {
client,
model: model.to_string(),
}
}
}
impl completion::CompletionModel for CompletionModel {
type Response = AwsConverseOutput;
async fn completion(
&self,
completion_request: completion::CompletionRequest,
) -> Result<completion::CompletionResponse<AwsConverseOutput>, CompletionError> {
let request = AwsCompletionRequest(completion_request);
let mut converse_builder = self
.client
.aws_client
.converse()
.model_id(self.model.as_str());
let tool_config = request.tools_config()?;
let prompt_with_history = request.prompt_with_history()?;
converse_builder = converse_builder
.set_additional_model_request_fields(request.additional_params())
.set_inference_config(request.inference_config())
.set_tool_config(tool_config)
.set_system(request.system_prompt())
.set_messages(Some(prompt_with_history));
let response = converse_builder
.send()
.await
.map_err(|sdk_error| Into::<CompletionError>::into(AwsSdkConverseError(sdk_error)))?;
AwsConverseOutput(response).try_into()
}
}

View File

@ -0,0 +1,121 @@
use aws_smithy_types::Blob;
use rig::embeddings::{self, Embedding, EmbeddingError};
use serde::{Deserialize, Serialize};
use crate::{client::Client, types::errors::AwsSdkInvokeModelError};
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct EmbeddingRequest {
pub input_text: String,
pub dimensions: usize,
pub normalize: bool,
}
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct EmbeddingResponse {
pub embedding: Vec<f64>,
pub input_text_token_count: usize,
}
/// `amazon.titan-embed-text-v1`
pub const AMAZON_TITAN_EMBED_TEXT_V1: &str = "amazon.titan-embed-text-v1";
/// `amazon.titan-embed-text-v2:0`
pub const AMAZON_TITAN_EMBED_TEXT_V2_0: &str = "amazon.titan-embed-text-v2:0";
/// `amazon.titan-embed-image-v1`
pub const AMAZON_TITAN_EMBED_IMAGE_V1: &str = "amazon.titan-embed-image-v1";
/// `cohere.embed-english-v3`
pub const COHERE_EMBED_ENGLISH_V3: &str = "cohere.embed-english-v3";
/// `cohere.embed-multilingual-v3`
pub const COHERE_EMBED_MULTILINGUAL_V3: &str = "cohere.embed-multilingual-v3";
#[derive(Clone)]
pub struct EmbeddingModel {
client: Client,
model: String,
ndims: Option<usize>,
}
impl EmbeddingModel {
pub fn new(client: Client, model: &str, ndims: Option<usize>) -> Self {
Self {
client,
model: model.to_string(),
ndims,
}
}
pub async fn document_to_embeddings(
&self,
request: EmbeddingRequest,
) -> Result<EmbeddingResponse, EmbeddingError> {
let input_document = serde_json::to_string(&request).map_err(EmbeddingError::JsonError)?;
let model_response = self
.client
.aws_client
.invoke_model()
.model_id(self.model.as_str())
.content_type("application/json")
.accept("application/json")
.body(Blob::new(input_document))
.send()
.await;
let response = model_response
.map_err(|sdk_error| AwsSdkInvokeModelError(sdk_error).into())
.map_err(|e: EmbeddingError| e)?;
let response_str = String::from_utf8(response.body.into_inner())
.map_err(|e| EmbeddingError::ResponseError(e.to_string()))?;
let result: EmbeddingResponse =
serde_json::from_str(&response_str).map_err(EmbeddingError::JsonError)?;
Ok(result)
}
}
impl embeddings::EmbeddingModel for EmbeddingModel {
const MAX_DOCUMENTS: usize = 1024;
fn ndims(&self) -> usize {
self.ndims.unwrap_or(0)
}
async fn embed_texts(
&self,
documents: impl IntoIterator<Item = String> + Send,
) -> Result<Vec<Embedding>, EmbeddingError> {
let documents: Vec<_> = documents.into_iter().collect();
let mut results = Vec::new();
let mut errors = Vec::new();
let mut iterator = documents.into_iter();
while let Some(embedding) = iterator.next().map(|doc| async move {
let request = EmbeddingRequest {
input_text: doc.to_owned(),
dimensions: self.ndims(),
normalize: true,
};
self.document_to_embeddings(request)
.await
.map(|embeddings| Embedding {
document: doc.to_owned(),
vec: embeddings.embedding,
})
}) {
match embedding.await {
Ok(embedding) => results.push(embedding),
Err(err) => errors.push(err),
}
}
match errors.as_slice() {
[] => Ok(results),
[err, ..] => Err(EmbeddingError::ResponseError(err.to_string())),
}
}
}

65
rig-bedrock/src/image.rs Normal file
View File

@ -0,0 +1,65 @@
use crate::client::Client;
use crate::types::errors::AwsSdkInvokeModelError;
use crate::types::text_to_image::{TextToImageGeneration, TextToImageResponse};
use aws_smithy_types::Blob;
use rig::image_generation::{
self, ImageGenerationError, ImageGenerationRequest, ImageGenerationResponse,
};
/// `amazon.titan-image-generator-v1`
pub const AMAZON_TITAN_IMAGE_GENERATOR_V1: &str = "amazon.titan-image-generator-v1";
/// `amazon.titan-image-generator-v2:0`
pub const AMAZON_TITAN_IMAGE_GENERATOR_V2_0: &str = "amazon.titan-image-generator-v2:0";
/// `amazon.nova-canvas-v1:0`
pub const AMAZON_NOVA_CANVAS: &str = "amazon.nova-canvas-v1:0";
#[derive(Clone)]
pub struct ImageGenerationModel {
pub(crate) client: Client,
pub model: String,
}
impl ImageGenerationModel {
pub fn new(client: Client, model: &str) -> Self {
Self {
client,
model: model.to_string(),
}
}
}
impl image_generation::ImageGenerationModel for ImageGenerationModel {
type Response = TextToImageResponse;
async fn image_generation(
&self,
generation_request: ImageGenerationRequest,
) -> Result<ImageGenerationResponse<Self::Response>, ImageGenerationError> {
let mut request = TextToImageGeneration::new(generation_request.prompt);
request.width(generation_request.width);
request.height(generation_request.height);
let body = serde_json::to_string(&request)?;
let model_response = self
.client
.aws_client
.invoke_model()
.model_id(self.model.as_str())
.content_type("application/json")
.accept("application/json")
.body(Blob::new(body))
.send()
.await
.map_err(|sdk_error| {
Into::<ImageGenerationError>::into(AwsSdkInvokeModelError(sdk_error))
})?;
let response_str = String::from_utf8(model_response.body.into_inner())
.map_err(|e| ImageGenerationError::ResponseError(e.to_string()))?;
let result: TextToImageResponse = serde_json::from_str(&response_str)
.map_err(|e| ImageGenerationError::ResponseError(e.to_string()))?;
result.try_into()
}
}

6
rig-bedrock/src/lib.rs Normal file
View File

@ -0,0 +1,6 @@
pub mod client;
pub mod completion;
pub mod embedding;
pub mod image;
pub mod streaming;
pub mod types;

View File

@ -0,0 +1,101 @@
use crate::types::completion_request::AwsCompletionRequest;
use crate::{completion::CompletionModel, types::errors::AwsSdkConverseStreamError};
use async_stream::stream;
use aws_sdk_bedrockruntime::types as aws_bedrock;
use rig::{
completion::CompletionError,
streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult},
};
#[derive(Default)]
struct ToolCallState {
name: String,
id: String,
input_json: String,
}
impl StreamingCompletionModel for CompletionModel {
async fn stream(
&self,
completion_request: rig::completion::CompletionRequest,
) -> Result<StreamingResult, CompletionError> {
let request = AwsCompletionRequest(completion_request);
let mut converse_builder = self
.client
.aws_client
.converse_stream()
.model_id(self.model.as_str());
let tool_config = request.tools_config()?;
let prompt_with_history = request.prompt_with_history()?;
converse_builder = converse_builder
.set_additional_model_request_fields(request.additional_params())
.set_inference_config(request.inference_config())
.set_tool_config(tool_config)
.set_system(request.system_prompt())
.set_messages(Some(prompt_with_history));
let response = converse_builder.send().await.map_err(|sdk_error| {
Into::<CompletionError>::into(AwsSdkConverseStreamError(sdk_error))
})?;
Ok(Box::pin(stream! {
let mut current_tool_call: Option<ToolCallState> = None;
let mut stream = response.stream;
while let Ok(Some(output)) = stream.recv().await {
match output {
aws_bedrock::ConverseStreamOutput::ContentBlockDelta(event) => {
let delta = event.delta.ok_or(CompletionError::ProviderError("The delta for a content block is missing".into()))?;
match delta {
aws_bedrock::ContentBlockDelta::Text(text) => {
if current_tool_call.is_none() {
yield Ok(StreamingChoice::Message(text))
}
},
aws_bedrock::ContentBlockDelta::ToolUse(tool) => {
if let Some(ref mut tool_call) = current_tool_call {
tool_call.input_json.push_str(tool.input());
}
},
_ => {}
}
},
aws_bedrock::ConverseStreamOutput::ContentBlockStart(event) => {
match event.start.ok_or(CompletionError::ProviderError("ContentBlockStart has no data".into()))? {
aws_bedrock::ContentBlockStart::ToolUse(tool_use) => {
current_tool_call = Some(ToolCallState {
name: tool_use.name,
id: tool_use.tool_use_id,
input_json: String::new(),
});
},
_ => yield Err(CompletionError::ProviderError("Stream is empty".into()))
}
},
aws_bedrock::ConverseStreamOutput::MessageStop(message_stop_event) => {
match message_stop_event.stop_reason {
aws_bedrock::StopReason::ToolUse => {
if let Some(tool_call) = current_tool_call.take() {
let tool_input = serde_json::from_str(tool_call.input_json.as_str())?;
yield Ok(StreamingChoice::ToolCall(
tool_call.name,
tool_call.id,
tool_input
));
} else {
yield Err(CompletionError::ProviderError("Failed to call tool".into()))
}
}
aws_bedrock::StopReason::MaxTokens => {
yield Err(CompletionError::ProviderError("Exceeded max tokens".into()))
}
_ => {}
}
},
_ => {}
}
}
}))
}
}

View File

@ -0,0 +1,155 @@
use aws_sdk_bedrockruntime::operation::converse::ConverseOutput;
use aws_sdk_bedrockruntime::types as aws_bedrock;
use rig::{
completion::CompletionError,
message::{AssistantContent, Text, ToolCall, ToolFunction},
OneOrMany,
};
use crate::types::message::RigMessage;
use super::json::AwsDocument;
use rig::completion;
#[derive(Clone)]
pub struct AwsConverseOutput(pub ConverseOutput);
impl TryFrom<AwsConverseOutput> for completion::CompletionResponse<AwsConverseOutput> {
type Error = CompletionError;
fn try_from(value: AwsConverseOutput) -> Result<Self, Self::Error> {
let message: RigMessage = value
.to_owned()
.0
.output
.ok_or(CompletionError::ProviderError(
"Model didn't return any output".into(),
))?
.as_message()
.map_err(|_| {
CompletionError::ProviderError(
"Failed to extract message from converse output".into(),
)
})?
.to_owned()
.try_into()?;
let choice = match message.0 {
completion::Message::Assistant { content } => Ok(content),
_ => Err(CompletionError::ResponseError(
"Response contained no message or tool call (empty)".to_owned(),
)),
}?;
if let Some(tool_use) = choice.iter().find_map(|content| match content {
AssistantContent::ToolCall(tool_call) => Some(tool_call.to_owned()),
_ => None,
}) {
return Ok(completion::CompletionResponse {
choice: OneOrMany::one(AssistantContent::ToolCall(ToolCall {
id: tool_use.id,
function: ToolFunction {
name: tool_use.function.name,
arguments: tool_use.function.arguments,
},
})),
raw_response: value,
});
}
Ok(completion::CompletionResponse {
choice,
raw_response: value,
})
}
}
pub struct RigAssistantContent(pub AssistantContent);
impl TryFrom<aws_bedrock::ContentBlock> for RigAssistantContent {
type Error = CompletionError;
fn try_from(value: aws_bedrock::ContentBlock) -> Result<Self, Self::Error> {
match value {
aws_bedrock::ContentBlock::Text(text) => {
Ok(RigAssistantContent(AssistantContent::Text(Text { text })))
}
aws_bedrock::ContentBlock::ToolUse(call) => Ok(RigAssistantContent(
completion::AssistantContent::tool_call(
&call.tool_use_id,
&call.name,
AwsDocument(call.input).into(),
),
)),
_ => Err(CompletionError::ProviderError(
"AWS Bedrock returned unsupported ContentBlock".into(),
)),
}
}
}
impl TryFrom<RigAssistantContent> for aws_bedrock::ContentBlock {
type Error = CompletionError;
fn try_from(value: RigAssistantContent) -> Result<Self, Self::Error> {
match value.0 {
AssistantContent::Text(text) => Ok(aws_bedrock::ContentBlock::Text(text.text)),
AssistantContent::ToolCall(tool_call) => {
let doc: AwsDocument = tool_call.function.arguments.into();
Ok(aws_bedrock::ContentBlock::ToolUse(
aws_bedrock::ToolUseBlock::builder()
.tool_use_id(tool_call.id)
.name(tool_call.function.name)
.input(doc.0)
.build()
.map_err(|e| CompletionError::ProviderError(e.to_string()))?,
))
}
}
}
}
#[cfg(test)]
mod tests {
use crate::types::assistant_content::RigAssistantContent;
use super::AwsConverseOutput;
use aws_sdk_bedrockruntime::types as aws_bedrock;
use rig::{completion, message::AssistantContent, OneOrMany};
#[test]
fn aws_converse_output_to_completion_response() {
let message = aws_bedrock::Message::builder()
.role(aws_bedrock::ConversationRole::Assistant)
.content(aws_bedrock::ContentBlock::Text("txt".into()))
.build()
.unwrap();
let output = aws_bedrock::ConverseOutput::Message(message);
let converse_output =
aws_sdk_bedrockruntime::operation::converse::ConverseOutput::builder()
.output(output)
.stop_reason(aws_bedrock::StopReason::EndTurn)
.build()
.unwrap();
let completion: Result<completion::CompletionResponse<AwsConverseOutput>, _> =
AwsConverseOutput(converse_output).try_into();
assert_eq!(completion.is_ok(), true);
let completion = completion.unwrap();
assert_eq!(
completion.choice,
OneOrMany::one(AssistantContent::Text("txt".into()))
);
}
#[test]
fn aws_content_block_to_assistant_content() {
let content_block = aws_bedrock::ContentBlock::Text("text".into());
let rig_assistant_content: Result<RigAssistantContent, _> = content_block.try_into();
assert_eq!(rig_assistant_content.is_ok(), true);
assert_eq!(
rig_assistant_content.unwrap().0,
AssistantContent::Text("text".into())
);
}
}

View File

@ -0,0 +1,85 @@
use crate::types::json::AwsDocument;
use crate::types::message::RigMessage;
use aws_sdk_bedrockruntime::types as aws_bedrock;
use aws_sdk_bedrockruntime::types::{
InferenceConfiguration, SystemContentBlock, Tool, ToolConfiguration, ToolInputSchema,
ToolSpecification,
};
use rig::completion::{CompletionError, Message};
pub struct AwsCompletionRequest(pub rig::completion::CompletionRequest);
impl AwsCompletionRequest {
pub fn additional_params(&self) -> Option<aws_smithy_types::Document> {
self.0
.additional_params
.to_owned()
.map(|params| params.into())
.map(|doc: AwsDocument| doc.0)
}
pub fn inference_config(&self) -> Option<InferenceConfiguration> {
let mut inference_configuration = InferenceConfiguration::builder();
if let Some(temperature) = &self.0.temperature {
inference_configuration =
inference_configuration.set_temperature(Some(*temperature as f32));
}
if let Some(max_tokens) = &self.0.max_tokens {
inference_configuration =
inference_configuration.set_max_tokens(Some(*max_tokens as i32));
}
Some(inference_configuration.build())
}
pub fn tools_config(&self) -> Result<Option<ToolConfiguration>, CompletionError> {
let mut tools = vec![];
for tool_definition in self.0.tools.iter() {
let doc: AwsDocument = tool_definition.parameters.clone().into();
let schema = ToolInputSchema::Json(doc.0);
let tool = Tool::ToolSpec(
ToolSpecification::builder()
.name(tool_definition.name.clone())
.set_description(Some(tool_definition.description.clone()))
.set_input_schema(Some(schema))
.build()
.map_err(|e| CompletionError::RequestError(e.into()))?,
);
tools.push(tool);
}
if !tools.is_empty() {
let config = ToolConfiguration::builder()
.set_tools(Some(tools))
.build()
.map_err(|e| CompletionError::RequestError(e.into()))?;
Ok(Some(config))
} else {
Ok(None)
}
}
pub fn system_prompt(&self) -> Option<Vec<SystemContentBlock>> {
self.0
.preamble
.to_owned()
.map(|system_prompt| vec![SystemContentBlock::Text(system_prompt)])
}
pub fn prompt_with_history(&self) -> Result<Vec<aws_bedrock::Message>, CompletionError> {
let mut chat_history = self.0.chat_history.to_owned();
let prompt_with_context = self.0.prompt_with_context();
let mut full_history: Vec<Message> = Vec::new();
full_history.append(&mut chat_history);
full_history.push(prompt_with_context);
full_history
.into_iter()
.map(|message| RigMessage(message).try_into())
.collect::<Result<Vec<aws_bedrock::Message>, _>>()
}
}

View File

@ -0,0 +1,153 @@
use aws_sdk_bedrockruntime::types as aws_bedrock;
use rig::{
completion::CompletionError,
message::{ContentFormat, Document},
};
pub(crate) use crate::types::media_types::RigDocumentMediaType;
use base64::{prelude::BASE64_STANDARD, Engine};
#[derive(Clone)]
pub struct RigDocument(pub Document);
impl TryFrom<RigDocument> for aws_bedrock::DocumentBlock {
type Error = CompletionError;
fn try_from(value: RigDocument) -> Result<Self, Self::Error> {
let maybe_format = value
.0
.media_type
.map(|doc| RigDocumentMediaType(doc).try_into());
let format = match maybe_format {
Some(Ok(document_format)) => Ok(Some(document_format)),
Some(Err(err)) => Err(err),
None => Ok(None),
}?;
let document_data = BASE64_STANDARD
.decode(value.0.data)
.map_err(|e| CompletionError::ProviderError(e.to_string()))?;
let data = aws_smithy_types::Blob::new(document_data);
let document_source = aws_bedrock::DocumentSource::Bytes(data);
let result = aws_bedrock::DocumentBlock::builder()
.source(document_source)
.name("Document")
.set_format(format)
.build()
.map_err(|e| CompletionError::ProviderError(e.to_string()))?;
Ok(result)
}
}
impl TryFrom<aws_bedrock::DocumentBlock> for RigDocument {
type Error = CompletionError;
fn try_from(value: aws_bedrock::DocumentBlock) -> Result<Self, Self::Error> {
let media_type: RigDocumentMediaType = value.format.try_into()?;
let media_type = media_type.0;
let data = match value.source {
Some(aws_bedrock::DocumentSource::Bytes(blob)) => {
let encoded_data = BASE64_STANDARD.encode(blob.into_inner());
Ok(encoded_data)
}
_ => Err(CompletionError::ProviderError(
"Document source is missing".into(),
)),
}?;
Ok(RigDocument(Document {
data,
format: Some(ContentFormat::Base64),
media_type: Some(media_type),
}))
}
}
#[cfg(test)]
mod tests {
use aws_sdk_bedrockruntime::types as aws_bedrock;
use base64::{prelude::BASE64_STANDARD, Engine};
use rig::{
completion::CompletionError,
message::{ContentFormat, Document, DocumentMediaType},
};
use crate::types::document::RigDocument;
#[test]
fn test_document_to_aws_document() {
let rig_document = RigDocument(Document {
data: "data".into(),
format: Some(ContentFormat::Base64),
media_type: Some(DocumentMediaType::PDF),
});
let aws_document: Result<aws_bedrock::DocumentBlock, _> = rig_document.clone().try_into();
assert_eq!(aws_document.is_ok(), true);
let aws_document = aws_document.unwrap();
assert_eq!(aws_document.format, aws_bedrock::DocumentFormat::Pdf);
let document_data = BASE64_STANDARD.decode(rig_document.0.data).unwrap();
let aws_document_bytes = aws_document
.source()
.unwrap()
.as_bytes()
.unwrap()
.as_ref()
.to_owned();
assert_eq!(aws_document_bytes, document_data)
}
#[test]
fn test_unsupported_document_to_aws_document() {
let rig_document = RigDocument(Document {
data: "data".into(),
format: Some(ContentFormat::Base64),
media_type: Some(DocumentMediaType::Javascript),
});
let aws_document: Result<aws_bedrock::DocumentBlock, _> = rig_document.clone().try_into();
assert_eq!(
aws_document.err().unwrap().to_string(),
CompletionError::ProviderError(
"Unsupported media type application/x-javascript".into()
)
.to_string()
)
}
#[test]
fn test_aws_document_to_rig_document() {
let data = aws_smithy_types::Blob::new("document_data");
let document_source = aws_bedrock::DocumentSource::Bytes(data);
let aws_document = aws_bedrock::DocumentBlock::builder()
.format(aws_bedrock::DocumentFormat::Pdf)
.name("Document")
.source(document_source)
.build()
.unwrap();
let rig_document: Result<RigDocument, _> = aws_document.clone().try_into();
assert_eq!(rig_document.is_ok(), true);
let rig_document = rig_document.unwrap().0;
assert_eq!(rig_document.media_type.unwrap(), DocumentMediaType::PDF)
}
#[test]
fn test_unsupported_aws_document_to_rig_document() {
let data = aws_smithy_types::Blob::new("document_data");
let document_source = aws_bedrock::DocumentSource::Bytes(data);
let aws_document = aws_bedrock::DocumentBlock::builder()
.format(aws_bedrock::DocumentFormat::Xlsx)
.name("Document")
.source(document_source)
.build()
.unwrap();
let rig_document: Result<RigDocument, _> = aws_document.clone().try_into();
assert_eq!(rig_document.is_ok(), false);
assert_eq!(
rig_document.err().unwrap().to_string(),
CompletionError::ProviderError("Unsupported media type xlsx".into()).to_string()
)
}
}

View File

@ -0,0 +1,81 @@
use aws_sdk_bedrockruntime::config::http::HttpResponse;
use aws_sdk_bedrockruntime::error::SdkError;
use aws_sdk_bedrockruntime::operation::converse::ConverseError;
use aws_sdk_bedrockruntime::operation::converse_stream::ConverseStreamError;
use aws_sdk_bedrockruntime::operation::invoke_model::InvokeModelError;
use rig::completion::CompletionError;
use rig::embeddings::EmbeddingError;
use rig::image_generation::ImageGenerationError;
pub struct AwsSdkInvokeModelError(pub SdkError<InvokeModelError, HttpResponse>);
impl AwsSdkInvokeModelError {
pub fn into_service_error(self) -> String {
let error: String = match self.0.into_service_error() {
InvokeModelError::ModelTimeoutException(e) => e.message.unwrap_or("The request took too long to process. Processing time exceeded the model timeout length.".into()),
InvokeModelError::AccessDeniedException(e) => e.message.unwrap_or("The request is denied because you do not have sufficient permissions to perform the requested action.".into()),
InvokeModelError::ResourceNotFoundException(e) => e.message.unwrap_or("The specified resource ARN was not found.".into()),
InvokeModelError::ThrottlingException(e) => e.message.unwrap_or("Your request was denied due to exceeding the account quotas for Amazon Bedrock.".into()),
InvokeModelError::ServiceUnavailableException(e) => e.message.unwrap_or("The service isn't currently available.".into()),
InvokeModelError::InternalServerException(e) => e.message.unwrap_or("An internal server error occurred.".into()),
InvokeModelError::ValidationException(e) => e.message.unwrap_or("The input fails to satisfy the constraints specified by Amazon Bedrock.".into()),
InvokeModelError::ModelNotReadyException(e) => e.message.unwrap_or("The model specified in the request is not ready to serve inference requests. The AWS SDK will automatically retry the operation up to 5 times.".into()),
InvokeModelError::ModelErrorException(e) => e.message.unwrap_or("The request failed due to an error while processing the model.".into()),
InvokeModelError::ServiceQuotaExceededException(e) => e.message.unwrap_or("Your request exceeds the service quota for your account.".into()),
_ => "An unexpected error occurred. Verify Internet connection or AWS keys".into(),
};
error
}
}
impl From<AwsSdkInvokeModelError> for ImageGenerationError {
fn from(value: AwsSdkInvokeModelError) -> Self {
ImageGenerationError::ProviderError(value.into_service_error())
}
}
impl From<AwsSdkInvokeModelError> for EmbeddingError {
fn from(value: AwsSdkInvokeModelError) -> Self {
EmbeddingError::ProviderError(value.into_service_error())
}
}
pub struct AwsSdkConverseError(pub SdkError<ConverseError, HttpResponse>);
impl From<AwsSdkConverseError> for CompletionError {
fn from(value: AwsSdkConverseError) -> Self {
let error: String = match value.0.into_service_error() {
ConverseError::ModelTimeoutException(e) => e.message.unwrap_or("The request took too long to process. Processing time exceeded the model timeout length.".into()),
ConverseError::AccessDeniedException(e) => e.message.unwrap_or("The request is denied because you do not have sufficient permissions to perform the requested action.".into()),
ConverseError::ResourceNotFoundException(e) => e.message.unwrap_or("The specified resource ARN was not found.".into()),
ConverseError::ThrottlingException(e) => e.message.unwrap_or("Your request was denied due to exceeding the account quotas for AWS Bedrock.".into()),
ConverseError::ServiceUnavailableException(e) => e.message.unwrap_or("The service isn't currently available.".into()),
ConverseError::InternalServerException(e) => e.message.unwrap_or("An internal server error occurred.".into()),
ConverseError::ValidationException(e) => e.message.unwrap_or("The input fails to satisfy the constraints specified by AWS Bedrock.".into()),
ConverseError::ModelNotReadyException(e) => e.message.unwrap_or("The model specified in the request is not ready to serve inference requests. The AWS SDK will automatically retry the operation up to 5 times.".into()),
ConverseError::ModelErrorException(e) => e.message.unwrap_or("The request failed due to an error while processing the model.".into()),
_ => String::from("An unexpected error occurred. Verify Internet connection or AWS keys")
};
CompletionError::ProviderError(error)
}
}
pub struct AwsSdkConverseStreamError(pub SdkError<ConverseStreamError, HttpResponse>);
impl From<AwsSdkConverseStreamError> for CompletionError {
fn from(value: AwsSdkConverseStreamError) -> Self {
let error: String = match value.0.into_service_error() {
ConverseStreamError::ModelTimeoutException(e) => e.message.unwrap(),
ConverseStreamError::AccessDeniedException(e) => e.message.unwrap(),
ConverseStreamError::ResourceNotFoundException(e) => e.message.unwrap(),
ConverseStreamError::ThrottlingException(e) => e.message.unwrap(),
ConverseStreamError::ServiceUnavailableException(e) => e.message.unwrap(),
ConverseStreamError::InternalServerException(e) => e.message.unwrap(),
ConverseStreamError::ModelStreamErrorException(e) => e.message.unwrap(),
ConverseStreamError::ValidationException(e) => e.message.unwrap(),
ConverseStreamError::ModelNotReadyException(e) => e.message.unwrap(),
ConverseStreamError::ModelErrorException(e) => e.message.unwrap(),
_ => "An unexpected error occurred. Verify Internet connection or AWS keys".into(),
};
CompletionError::ProviderError(error)
}
}

View File

@ -0,0 +1,129 @@
use aws_sdk_bedrockruntime::types as aws_bedrock;
use rig::{
completion::CompletionError,
message::{ContentFormat, Image, ImageMediaType, MimeType},
};
use base64::{prelude::BASE64_STANDARD, Engine};
#[derive(Clone)]
pub struct RigImage(pub Image);
impl TryFrom<RigImage> for aws_bedrock::ImageBlock {
type Error = CompletionError;
fn try_from(image: RigImage) -> Result<Self, Self::Error> {
let maybe_format: Option<Result<aws_bedrock::ImageFormat, CompletionError>> =
image.0.media_type.map(|f| match f {
ImageMediaType::JPEG => Ok(aws_bedrock::ImageFormat::Jpeg),
ImageMediaType::PNG => Ok(aws_bedrock::ImageFormat::Png),
ImageMediaType::GIF => Ok(aws_bedrock::ImageFormat::Gif),
ImageMediaType::WEBP => Ok(aws_bedrock::ImageFormat::Webp),
e => Err(CompletionError::ProviderError(format!(
"Unsupported format {}",
e.to_mime_type()
))),
});
let format = match maybe_format {
Some(Ok(image_format)) => Ok(Some(image_format)),
Some(Err(err)) => Err(err),
None => Ok(None),
}?;
let img_data = BASE64_STANDARD
.decode(image.0.data)
.map_err(|e| CompletionError::ProviderError(e.to_string()))?;
let blob = aws_smithy_types::Blob::new(img_data);
let result = aws_bedrock::ImageBlock::builder()
.set_format(format)
.source(aws_bedrock::ImageSource::Bytes(blob))
.build()
.map_err(|e| CompletionError::ProviderError(e.to_string()))?;
Ok(result)
}
}
impl TryFrom<aws_bedrock::ImageBlock> for RigImage {
type Error = CompletionError;
fn try_from(image: aws_bedrock::ImageBlock) -> Result<Self, Self::Error> {
let media_type = match image.format {
aws_bedrock::ImageFormat::Gif => Ok(ImageMediaType::GIF),
aws_bedrock::ImageFormat::Jpeg => Ok(ImageMediaType::JPEG),
aws_bedrock::ImageFormat::Png => Ok(ImageMediaType::PNG),
aws_bedrock::ImageFormat::Webp => Ok(ImageMediaType::WEBP),
e => Err(CompletionError::ProviderError(format!(
"Unsupported format {}",
e
))),
}?;
let data = match image.source {
Some(aws_bedrock::ImageSource::Bytes(blob)) => {
let encoded_img = BASE64_STANDARD.encode(blob.into_inner());
Ok(encoded_img)
}
_ => Err(CompletionError::ProviderError(
"Image source is missing".into(),
)),
}?;
Ok(RigImage(Image {
data,
format: Some(ContentFormat::Base64),
media_type: Some(media_type),
detail: None,
}))
}
}
#[cfg(test)]
mod tests {
use aws_sdk_bedrockruntime::types as aws_bedrock;
use base64::{prelude::BASE64_STANDARD, Engine};
use rig::{
completion::CompletionError,
message::{ContentFormat, Image, ImageMediaType},
};
use crate::types::image::RigImage;
#[test]
fn test_image_to_aws_image() {
let rig_image = RigImage(Image {
data: BASE64_STANDARD.encode("img_data"),
format: Some(ContentFormat::Base64),
media_type: Some(ImageMediaType::JPEG),
detail: None,
});
let aws_image: Result<aws_bedrock::ImageBlock, _> = rig_image.clone().try_into();
assert_eq!(aws_image.is_ok(), true);
let aws_image = aws_image.unwrap();
assert_eq!(aws_image.format, aws_bedrock::ImageFormat::Jpeg);
let img_data = BASE64_STANDARD.decode(rig_image.0.data).unwrap();
let aws_image_bytes = aws_image
.source()
.unwrap()
.as_bytes()
.unwrap()
.as_ref()
.to_owned();
assert_eq!(aws_image_bytes, img_data)
}
#[test]
fn test_unsupported_image_to_aws_image() {
let rig_image = RigImage(Image {
data: BASE64_STANDARD.encode("img_data"),
format: Some(ContentFormat::Base64),
media_type: Some(ImageMediaType::HEIC),
detail: None,
});
let aws_image: Result<aws_bedrock::ImageBlock, _> = rig_image.clone().try_into();
assert_eq!(
aws_image.err().unwrap().to_string(),
CompletionError::ProviderError("Unsupported format image/heic".into()).to_string()
)
}
}

View File

@ -0,0 +1,182 @@
use aws_smithy_types::{Document, Number};
use serde_json::{Map, Value};
use std::collections::HashMap;
#[derive(Debug)]
pub struct AwsDocument(pub Document);
impl From<AwsDocument> for Value {
fn from(value: AwsDocument) -> Self {
match value.0 {
Document::Object(obj) => {
let documents = obj
.into_iter()
.map(|(k, v)| (k, AwsDocument(v).into()))
.collect::<Map<_, _>>();
Value::Object(documents)
}
Document::Array(arr) => {
let documents = arr.into_iter().map(|v| AwsDocument(v).into()).collect();
Value::Array(documents)
}
Document::Number(Number::PosInt(number)) => {
Value::Number(serde_json::Number::from(number))
}
Document::Number(Number::NegInt(number)) => {
Value::Number(serde_json::Number::from(number))
}
Document::Number(Number::Float(number)) => match serde_json::Number::from_f64(number) {
Some(n) => Value::Number(n),
// https://www.rfc-editor.org/rfc/rfc7159
// Numeric values that cannot be represented in the grammar (such as Infinity and NaN) are not permitted.
None => Value::Null,
},
Document::String(s) => Value::String(s),
Document::Bool(b) => Value::Bool(b),
Document::Null => Value::Null,
}
}
}
impl From<Value> for AwsDocument {
fn from(value: Value) -> Self {
match value {
Value::Null => AwsDocument(Document::Null),
Value::Bool(b) => AwsDocument(Document::Bool(b)),
Value::Number(num) => {
if let Some(i) = num.as_i64() {
match i > 0 {
true => AwsDocument(Document::Number(Number::PosInt(i as u64))),
false => AwsDocument(Document::Number(Number::NegInt(i))),
}
} else if let Some(f) = num.as_f64() {
AwsDocument(Document::Number(Number::Float(f)))
} else {
AwsDocument(Document::Null)
}
}
Value::String(s) => AwsDocument(Document::String(s)),
Value::Array(arr) => {
let documents = arr
.into_iter()
.map(|json| json.into())
.map(|aws: AwsDocument| aws.0)
.collect();
AwsDocument(Document::Array(documents))
}
Value::Object(obj) => {
let documents = obj
.into_iter()
.map(|(k, v)| {
let doc: AwsDocument = v.into();
(k, doc.0)
})
.collect::<HashMap<_, _>>();
AwsDocument(Document::Object(documents))
}
}
}
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use aws_smithy_types::{Document, Number};
use serde_json::Value;
use crate::types::json::AwsDocument;
#[test]
fn test_json_to_aws_document() {
let json = r#"
{
"type": "object",
"is_enabled": true,
"version": 42,
"fraction": 1.23,
"negative": -11,
"properties": {
"x": {
"type": "number",
"description": "The first number to add"
},
"y": {
"type": "number",
"description": "The second number to add"
}
},
"required":["x", "y", null]
}
"#;
let value: Value = serde_json::from_str(json).unwrap();
let document: AwsDocument = value.into();
println!("{:?}", document);
}
#[test]
fn test_aws_document_to_json() {
let document = AwsDocument(Document::Object(HashMap::from([
(
String::from("type"),
Document::String(String::from("object")),
),
(
String::from("version"),
Document::Number(Number::PosInt(42)),
),
(
String::from("fraction"),
Document::Number(Number::Float(1.23)),
),
(
String::from("negative"),
Document::Number(Number::NegInt(-11)),
),
(String::from("is_enabled"), Document::Bool(true)),
(
String::from("properties"),
Document::Object(HashMap::from([
(
String::from("x"),
Document::Object(HashMap::from([
(
String::from("type"),
Document::String(String::from("number")),
),
(
String::from("description"),
Document::String(String::from("The first number to add")),
),
])),
),
(
String::from("y"),
Document::Object(HashMap::from([
(
String::from("type"),
Document::String(String::from("number")),
),
(
String::from("description"),
Document::String(String::from("The second number to add")),
),
])),
),
])),
),
(
String::from("required"),
Document::Array(vec![
Document::String(String::from("x")),
Document::String(String::from("y")),
Document::Null,
]),
),
])));
let json: Value = document.into();
println!("{:?}", json);
}
}

View File

@ -0,0 +1,43 @@
use aws_sdk_bedrockruntime::types::DocumentFormat;
use rig::{
completion::CompletionError,
message::{DocumentMediaType, MimeType},
};
pub struct RigDocumentMediaType(pub DocumentMediaType);
impl TryFrom<RigDocumentMediaType> for DocumentFormat {
type Error = CompletionError;
fn try_from(value: RigDocumentMediaType) -> Result<Self, Self::Error> {
match value.0 {
DocumentMediaType::PDF => Ok(DocumentFormat::Pdf),
DocumentMediaType::TXT => Ok(DocumentFormat::Txt),
DocumentMediaType::HTML => Ok(DocumentFormat::Html),
DocumentMediaType::MARKDOWN => Ok(DocumentFormat::Md),
DocumentMediaType::CSV => Ok(DocumentFormat::Csv),
e => Err(CompletionError::ProviderError(format!(
"Unsupported media type {}",
e.to_mime_type()
))),
}
}
}
impl TryFrom<DocumentFormat> for RigDocumentMediaType {
type Error = CompletionError;
fn try_from(value: DocumentFormat) -> Result<Self, Self::Error> {
match value {
DocumentFormat::Csv => Ok(RigDocumentMediaType(DocumentMediaType::CSV)),
DocumentFormat::Html => Ok(RigDocumentMediaType(DocumentMediaType::HTML)),
DocumentFormat::Md => Ok(RigDocumentMediaType(DocumentMediaType::MARKDOWN)),
DocumentFormat::Pdf => Ok(RigDocumentMediaType(DocumentMediaType::PDF)),
DocumentFormat::Txt => Ok(RigDocumentMediaType(DocumentMediaType::TXT)),
e => Err(CompletionError::ProviderError(format!(
"Unsupported media type {}",
e
))),
}
}
}

View File

@ -0,0 +1,111 @@
use aws_sdk_bedrockruntime::types as aws_bedrock;
use rig::{
completion::CompletionError,
message::{AssistantContent, Message, UserContent},
OneOrMany,
};
use super::{assistant_content::RigAssistantContent, user_content::RigUserContent};
pub struct RigMessage(pub Message);
impl TryFrom<RigMessage> for aws_bedrock::Message {
type Error = CompletionError;
fn try_from(value: RigMessage) -> Result<Self, Self::Error> {
let result = match value.0 {
Message::User { content } => {
let message_content = content
.into_iter()
.map(|user_content| RigUserContent(user_content).try_into())
.collect::<Result<Vec<Vec<_>>, _>>()
.map_err(|e| CompletionError::RequestError(Box::new(e)))
.map(|nested| nested.into_iter().flatten().collect())?;
aws_bedrock::Message::builder()
.role(aws_bedrock::ConversationRole::User)
.set_content(Some(message_content))
.build()
.map_err(|e| CompletionError::RequestError(Box::new(e)))?
}
Message::Assistant { content } => aws_bedrock::Message::builder()
.role(aws_bedrock::ConversationRole::Assistant)
.set_content(Some(
content
.into_iter()
.map(|content| RigAssistantContent(content).try_into())
.collect::<Result<Vec<aws_bedrock::ContentBlock>, _>>()?,
))
.build()
.map_err(|e| CompletionError::RequestError(Box::new(e)))?,
};
Ok(result)
}
}
impl TryFrom<aws_bedrock::Message> for RigMessage {
type Error = CompletionError;
fn try_from(message: aws_bedrock::Message) -> Result<Self, Self::Error> {
match message.role {
aws_bedrock::ConversationRole::Assistant => {
let assistant_content = message
.content
.into_iter()
.map(|c| c.try_into())
.collect::<Result<Vec<RigAssistantContent>, _>>()?
.into_iter()
.map(|rig_assistant_content| rig_assistant_content.0)
.collect::<Vec<AssistantContent>>();
let content = OneOrMany::many(assistant_content)
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;
Ok(RigMessage(Message::Assistant { content }))
}
aws_bedrock::ConversationRole::User => {
let user_content = message
.content
.into_iter()
.map(|c| c.try_into())
.collect::<Result<Vec<RigUserContent>, _>>()?
.into_iter()
.map(|user_content| user_content.0)
.collect::<Vec<UserContent>>();
let content = OneOrMany::many(user_content)
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;
Ok(RigMessage(Message::User { content }))
}
_ => Err(CompletionError::ProviderError(
"AWS Bedrock returned unsupported ConversationRole".into(),
)),
}
}
}
#[cfg(test)]
mod tests {
use crate::types::message::RigMessage;
use aws_sdk_bedrockruntime::types as aws_bedrock;
use rig::{
message::{Message, UserContent},
OneOrMany,
};
#[test]
fn message_to_aws_message() {
let message = Message::User {
content: OneOrMany::one(UserContent::Text("text".into())),
};
let aws_message: Result<aws_bedrock::Message, _> = RigMessage(message).try_into();
assert_eq!(aws_message.is_ok(), true);
let aws_message = aws_message.unwrap();
assert_eq!(aws_message.role, aws_bedrock::ConversationRole::User);
assert_eq!(
aws_message.content,
vec![aws_bedrock::ContentBlock::Text("text".into())]
);
}
}

View File

@ -0,0 +1,11 @@
pub(crate) mod assistant_content;
pub(crate) mod completion_request;
pub(crate) mod document;
pub(crate) mod errors;
pub(crate) mod image;
pub(crate) mod json;
pub(crate) mod media_types;
pub(crate) mod message;
pub(crate) mod text_to_image;
pub(crate) mod tool;
pub(crate) mod user_content;

View File

@ -0,0 +1,129 @@
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use rig::image_generation;
use rig::image_generation::ImageGenerationError;
use serde::{Deserialize, Serialize};
#[derive(Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ImageQuality {
Standard,
Premium,
}
#[derive(Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ImageGenerationConfig {
// The quality of the image.
// Default: standard
pub quality: Option<ImageQuality>,
// The number of images to generate.
// Default: 1, Minimum: 1, Maximum: 5
#[serde(skip_serializing_if = "Option::is_none")]
pub number_of_images: Option<u32>,
// The height of the image in pixels.
pub height: Option<u32>,
// The width of the image in pixels.
pub width: Option<u32>,
// Specifies how strongly the generated image should adhere to the prompt. Use a lower value to introduce more randomness in the generation.
// Default: 8.0. Minimum: 1.1, Maximum: 10.0
#[serde(skip_serializing_if = "Option::is_none")]
pub cfg_scale: Option<f32>,
// Use to control and reproduce results. Determines the initial noise setting.
// Use the same seed and the same settings as a previous run to allow inference to create a similar image.
// Default: 42, Minimum: 0, Maximum: 2147483646
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<u32>,
}
impl Default for ImageGenerationConfig {
fn default() -> Self {
ImageGenerationConfig {
quality: Some(ImageQuality::Standard),
number_of_images: Some(1),
height: Some(512),
width: Some(512),
cfg_scale: None,
seed: None,
}
}
}
#[derive(Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TextToImageParams {
pub text: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub negative_text: Option<String>,
}
impl TextToImageParams {
pub fn new(text: String) -> Self {
Self {
text,
negative_text: None,
}
}
}
#[derive(Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TextToImageGeneration {
pub task_type: &'static str,
pub text_to_image_params: TextToImageParams,
pub image_generation_config: ImageGenerationConfig,
}
impl TextToImageGeneration {
pub(crate) fn new(text: String) -> TextToImageGeneration {
TextToImageGeneration {
task_type: "TEXT_IMAGE",
text_to_image_params: TextToImageParams::new(text),
image_generation_config: Default::default(),
}
}
pub fn height(&mut self, height: u32) -> &Self {
self.image_generation_config.height = Some(height);
self
}
pub fn width(&mut self, width: u32) -> &Self {
self.image_generation_config.width = Some(width);
self
}
}
#[derive(Clone, Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct TextToImageResponse {
pub images: Option<Vec<String>>,
pub error: Option<String>,
}
impl TryFrom<TextToImageResponse>
for image_generation::ImageGenerationResponse<TextToImageResponse>
{
type Error = ImageGenerationError;
fn try_from(value: TextToImageResponse) -> Result<Self, Self::Error> {
if let Some(error) = value.error {
return Err(ImageGenerationError::ResponseError(error));
}
if let Some(images) = value.to_owned().images {
let data = BASE64_STANDARD
.decode(&images[0])
.expect("Could not decode image.");
return Ok(Self {
image: data,
response: value,
});
}
Err(ImageGenerationError::ResponseError(
"Malformed response from model".to_string(),
))
}
}

View File

@ -0,0 +1,124 @@
use aws_sdk_bedrockruntime::types as aws_bedrock;
use rig::{
completion::CompletionError,
message::{Text, ToolResultContent},
};
use serde_json::Value;
use super::{image::RigImage, json::AwsDocument};
pub struct RigToolResultContent(pub ToolResultContent);
impl TryFrom<RigToolResultContent> for aws_bedrock::ToolResultContentBlock {
type Error = CompletionError;
fn try_from(value: RigToolResultContent) -> Result<Self, Self::Error> {
match value.0 {
ToolResultContent::Text(text) => {
Ok(aws_bedrock::ToolResultContentBlock::Text(text.text))
}
ToolResultContent::Image(image) => {
let image = RigImage(image).try_into()?;
Ok(aws_bedrock::ToolResultContentBlock::Image(image))
}
}
}
}
impl TryFrom<aws_bedrock::ToolResultContentBlock> for RigToolResultContent {
type Error = CompletionError;
fn try_from(value: aws_bedrock::ToolResultContentBlock) -> Result<Self, Self::Error> {
match value {
aws_bedrock::ToolResultContentBlock::Image(image) => {
let image: RigImage = image.try_into()?;
Ok(RigToolResultContent(ToolResultContent::Image(image.0)))
}
aws_bedrock::ToolResultContentBlock::Json(document) => {
let json: Value = AwsDocument(document).into();
Ok(RigToolResultContent(ToolResultContent::Text(Text {
text: json.to_string(),
})))
}
aws_bedrock::ToolResultContentBlock::Text(text) => {
Ok(RigToolResultContent(ToolResultContent::Text(Text { text })))
}
_ => Err(CompletionError::ProviderError(
"ToolResultContentBlock contains unsupported variant".into(),
)),
}
}
}
#[cfg(test)]
mod tests {
use aws_sdk_bedrockruntime::types as aws_bedrock;
use base64::{prelude::BASE64_STANDARD, Engine};
use rig::{
completion::CompletionError,
message::{ContentFormat, Image, ImageMediaType, Text, ToolResultContent},
};
use crate::types::tool::RigToolResultContent;
#[test]
fn rig_tool_text_to_aws_tool() {
let tool = RigToolResultContent(ToolResultContent::Text(Text { text: "42".into() }));
let aws_tool: Result<aws_bedrock::ToolResultContentBlock, _> = tool.try_into();
assert_eq!(aws_tool.is_ok(), true);
assert_eq!(
String::from(aws_tool.unwrap().as_text().unwrap()),
String::from("42")
);
}
#[test]
fn rig_tool_image_to_aws_tool() {
let image = Image {
data: BASE64_STANDARD.encode("img_data"),
format: Some(ContentFormat::Base64),
media_type: Some(ImageMediaType::JPEG),
detail: None,
};
let tool = RigToolResultContent(ToolResultContent::Image(image));
let aws_tool: Result<aws_bedrock::ToolResultContentBlock, _> = tool.try_into();
assert_eq!(aws_tool.is_ok(), true);
assert_eq!(aws_tool.unwrap().is_image(), true)
}
#[test]
fn aws_tool_to_rig_tool() {
let aws_tool = aws_bedrock::ToolResultContentBlock::Text("txt".into());
let tool: Result<RigToolResultContent, _> = aws_tool.try_into();
assert_eq!(tool.is_ok(), true);
let tool = match tool.unwrap().0 {
ToolResultContent::Text(text) => Ok(text),
_ => Err("tool doesn't contain text"),
};
assert_eq!(tool.is_ok(), true);
assert_eq!(tool.unwrap().text, String::from("txt"))
}
#[test]
fn aws_tool_to_unsupported_rig_tool() {
let document_source =
aws_bedrock::DocumentSource::Bytes(aws_smithy_types::Blob::new("document_data"));
let aws_document = aws_bedrock::DocumentBlock::builder()
.format(aws_bedrock::DocumentFormat::Pdf)
.name("Document")
.source(document_source)
.build()
.unwrap();
let aws_tool = aws_bedrock::ToolResultContentBlock::Document(aws_document);
let tool: Result<RigToolResultContent, _> = aws_tool.try_into();
assert_eq!(tool.is_ok(), false);
assert_eq!(
tool.err().unwrap().to_string(),
CompletionError::ProviderError(
"ToolResultContentBlock contains unsupported variant".into()
)
.to_string()
)
}
}

View File

@ -0,0 +1,173 @@
use aws_sdk_bedrockruntime::types as aws_bedrock;
use rig::{
completion::CompletionError,
message::{Text, ToolResult, ToolResultContent, UserContent},
OneOrMany,
};
use super::{document::RigDocument, image::RigImage, tool::RigToolResultContent};
pub struct RigUserContent(pub UserContent);
impl TryFrom<aws_bedrock::ContentBlock> for RigUserContent {
type Error = CompletionError;
fn try_from(value: aws_bedrock::ContentBlock) -> Result<Self, Self::Error> {
match value {
aws_bedrock::ContentBlock::Text(text) => {
Ok(RigUserContent(UserContent::Text(Text { text })))
}
aws_bedrock::ContentBlock::ToolResult(tool_result) => {
let tool_result_contents = tool_result
.content
.into_iter()
.map(|tool| tool.try_into())
.collect::<Result<Vec<RigToolResultContent>, _>>()?
.into_iter()
.map(|rt| rt.0)
.collect::<Vec<ToolResultContent>>();
let tool_results = OneOrMany::many(tool_result_contents).map_err(|_| {
CompletionError::ProviderError("ToolResult returned invalid response".into())
})?;
Ok(RigUserContent(UserContent::ToolResult(ToolResult {
id: tool_result.tool_use_id,
content: tool_results,
})))
}
aws_bedrock::ContentBlock::Document(document) => {
let doc: RigDocument = document.try_into()?;
Ok(RigUserContent(UserContent::Document(doc.0)))
}
aws_bedrock::ContentBlock::Image(image) => {
let image: RigImage = image.try_into()?;
Ok(RigUserContent(UserContent::Image(image.0)))
}
_ => Err(CompletionError::ProviderError(
"ToolResultContentBlock contains unsupported variant".into(),
)),
}
}
}
impl TryFrom<RigUserContent> for Vec<aws_bedrock::ContentBlock> {
type Error = CompletionError;
fn try_from(value: RigUserContent) -> Result<Self, Self::Error> {
match value.0 {
UserContent::Text(text) => Ok(vec![aws_bedrock::ContentBlock::Text(text.text)]),
UserContent::ToolResult(tool_result) => {
let builder = aws_bedrock::ToolResultBlock::builder()
.tool_use_id(tool_result.id)
.set_content(Some(
tool_result
.content
.into_iter()
.map(|tool| RigToolResultContent(tool).try_into())
.collect::<Result<Vec<aws_bedrock::ToolResultContentBlock>, _>>()?,
))
.build()
.map_err(|e| CompletionError::ProviderError(e.to_string()))?;
Ok(vec![aws_bedrock::ContentBlock::ToolResult(builder)])
}
UserContent::Image(image) => {
let image = RigImage(image).try_into()?;
Ok(vec![aws_bedrock::ContentBlock::Image(image)])
}
UserContent::Document(document) => {
let doc = RigDocument(document).try_into()?;
// AWS documentations: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-call.html
// In the content field of the Message object, you must also include a text field with a prompt related to the document.
Ok(vec![
aws_bedrock::ContentBlock::Text("Use provided document".to_string()),
aws_bedrock::ContentBlock::Document(doc),
])
}
UserContent::Audio(_) => Err(CompletionError::ProviderError(
"Audio is not supported".into(),
)),
}
}
}
#[cfg(test)]
mod tests {
use crate::types::user_content::RigUserContent;
use aws_sdk_bedrockruntime::types as aws_bedrock;
use rig::{
completion::CompletionError,
message::{ToolResultContent, UserContent},
OneOrMany,
};
#[test]
fn aws_content_block_to_user_content() {
let cb = aws_bedrock::ContentBlock::Text("42".into());
let user_content: Result<RigUserContent, _> = cb.try_into();
assert_eq!(user_content.is_ok(), true);
let content = match user_content.unwrap().0 {
rig::message::UserContent::Text(text) => Ok(text),
_ => Err("Invalid content type"),
};
assert_eq!(content.is_ok(), true);
assert_eq!(content.unwrap().text, "42")
}
#[test]
fn aws_content_block_tool_to_user_content() {
let cb = aws_bedrock::ContentBlock::ToolResult(
aws_bedrock::ToolResultBlock::builder()
.tool_use_id("123")
.content(aws_bedrock::ToolResultContentBlock::Text("content".into()))
.build()
.unwrap(),
);
let user_content: Result<RigUserContent, _> = cb.try_into();
assert_eq!(user_content.is_ok(), true);
let content = match user_content.unwrap().0 {
rig::message::UserContent::ToolResult(tool_result) => Ok(tool_result),
_ => Err("Invalid content type"),
};
assert_eq!(content.is_ok(), true);
let content = content.unwrap();
assert_eq!(content.id, "123");
assert_eq!(
content.content,
OneOrMany::one(ToolResultContent::Text("content".into()))
)
}
#[test]
fn aws_unsupported_content_block_to_user_content() {
let cb = aws_bedrock::ContentBlock::GuardContent(
aws_bedrock::GuardrailConverseContentBlock::Text(
aws_bedrock::GuardrailConverseTextBlock::builder()
.text("stuff")
.build()
.unwrap(),
),
);
let user_content: Result<RigUserContent, _> = cb.try_into();
assert_eq!(user_content.is_ok(), false);
assert_eq!(
user_content.err().unwrap().to_string(),
CompletionError::ProviderError(
"ToolResultContentBlock contains unsupported variant".into()
)
.to_string()
)
}
#[test]
fn user_content_to_aws_content_block() {
let uc = RigUserContent(UserContent::Text("txt".into()));
let aws_content_blocks: Result<Vec<aws_bedrock::ContentBlock>, _> = uc.try_into();
assert_eq!(aws_content_blocks.is_ok(), true);
let aws_content_blocks = aws_content_blocks.unwrap();
assert_eq!(
aws_content_blocks,
vec![aws_bedrock::ContentBlock::Text("txt".into())]
);
}
}

View File

@ -41,6 +41,7 @@ pub trait ImageGeneration<M: ImageGenerationModel> {
> + Send;
}
#[derive(Debug)]
pub struct ImageGenerationResponse<T> {
pub image: Vec<u8>,
pub response: T,

View File

@ -151,7 +151,7 @@ pub enum Message {
)]
tool_calls: Vec<ToolCall>,
},
#[serde(rename = "Tool")]
#[serde(rename = "tool")]
ToolResult {
tool_call_id: String,
content: String,

View File

@ -102,7 +102,7 @@ pub async fn stream_to_stdout<M: StreamingCompletionModel>(
.tools
.call(&name, params.to_string())
.await
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
.map_err(|e| std::io::Error::other(e.to_string()))?;
println!("\nResult: {}", res);
}
Err(e) => {

View File

@ -277,7 +277,7 @@ impl Neo4jClient {
/// ### Arguments
/// * `index_name` - The name of the index to create.
/// * `node_label` - The label of the nodes to which the index will be applied. For example, if your nodes have
/// the label `:Movie`, pass "Movie" as the `node_label` parameter.
/// the label `:Movie`, pass "Movie" as the `node_label` parameter.
/// * `embedding_prop_name` (optional) - The name of the property that contains the embedding vectors. Defaults to "embedding".
///
pub async fn create_vector_index(

View File

@ -211,7 +211,7 @@ impl<M: EmbeddingModel + std::marker::Sync + Send> VectorStoreIndex for Neo4jVec
/// #### Generic Type Parameters
///
/// - `T`: The type used to deserialize the result from the Neo4j query.
/// It must implement the `serde::Deserialize` trait.
/// It must implement the `serde::Deserialize` trait.
///
/// #### Returns
///

View File

@ -29,7 +29,7 @@ impl<M: EmbeddingModel> QdrantVectorStore<M> {
/// * `client` - Qdrant client instance
/// * `model` - Embedding model instance
/// * `query_params` - Search parameters for vector queries
/// Reference: <https://api.qdrant.tech/v-1-12-x/api-reference/search/query-points>
/// Reference: <https://api.qdrant.tech/v-1-12-x/api-reference/search/query-points>
pub fn new(client: Qdrant, model: M, query_params: QueryPoints) -> Self {
Self {
client,