mirror of https://github.com/0xplaygrounds/rig
Merge branch 'main' of https://github.com/0xPlaygrounds/rig into fix/multiple-tool-calling
This commit is contained in:
commit
d618d1a435
|
@ -1134,6 +1134,30 @@ dependencies = [
|
|||
"uuid 1.16.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aws-sdk-bedrockruntime"
|
||||
version = "1.77.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4198493316dab97e1fed7716f3823462b73a34c518f4ee7b9799921645e232e5"
|
||||
dependencies = [
|
||||
"aws-credential-types",
|
||||
"aws-runtime",
|
||||
"aws-smithy-async",
|
||||
"aws-smithy-eventstream",
|
||||
"aws-smithy-http",
|
||||
"aws-smithy-json",
|
||||
"aws-smithy-runtime",
|
||||
"aws-smithy-runtime-api",
|
||||
"aws-smithy-types",
|
||||
"aws-types",
|
||||
"bytes",
|
||||
"fastrand",
|
||||
"http 0.2.12",
|
||||
"once_cell",
|
||||
"regex-lite",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aws-sdk-dynamodb"
|
||||
version = "1.69.0"
|
||||
|
@ -1258,12 +1282,24 @@ dependencies = [
|
|||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aws-smithy-eventstream"
|
||||
version = "0.60.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7c45d3dddac16c5c59d553ece225a88870cf81b7b813c9cc17b78cf4685eac7a"
|
||||
dependencies = [
|
||||
"aws-smithy-types",
|
||||
"bytes",
|
||||
"crc32fast",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aws-smithy-http"
|
||||
version = "0.62.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c5949124d11e538ca21142d1fba61ab0a2a2c1bc3ed323cdb3e4b878bfb83166"
|
||||
dependencies = [
|
||||
"aws-smithy-eventstream",
|
||||
"aws-smithy-runtime-api",
|
||||
"aws-smithy-types",
|
||||
"bytes",
|
||||
|
@ -8584,6 +8620,27 @@ version = "0.8.50"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57397d16646700483b67d2dd6511d79318f9d057fdbd21a4066aeac8b41d310a"
|
||||
|
||||
[[package]]
|
||||
name = "rig-bedrock"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-stream",
|
||||
"aws-config",
|
||||
"aws-sdk-bedrockruntime",
|
||||
"aws-smithy-types",
|
||||
"base64 0.22.1",
|
||||
"reqwest 0.12.15",
|
||||
"rig-core 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rig-derive",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rig-core"
|
||||
version = "0.11.0"
|
||||
|
@ -8617,6 +8674,27 @@ dependencies = [
|
|||
"worker",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rig-core"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ff893305131b471009ab11df388612beb603ed94bb12412c256fe197b7591aa6"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"futures",
|
||||
"glob",
|
||||
"mime_guess",
|
||||
"ordered-float",
|
||||
"reqwest 0.12.15",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 1.0.69",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rig-derive"
|
||||
version = "0.1.0"
|
||||
|
@ -8634,7 +8712,7 @@ dependencies = [
|
|||
"anyhow",
|
||||
"ethers",
|
||||
"reqwest 0.12.15",
|
||||
"rig-core",
|
||||
"rig-core 0.11.0",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
@ -8649,7 +8727,7 @@ version = "0.1.4"
|
|||
dependencies = [
|
||||
"anyhow",
|
||||
"fastembed",
|
||||
"rig-core",
|
||||
"rig-core 0.11.0",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
@ -8667,7 +8745,7 @@ dependencies = [
|
|||
"futures",
|
||||
"httpmock",
|
||||
"lancedb",
|
||||
"rig-core",
|
||||
"rig-core 0.11.0",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
|
@ -8681,7 +8759,7 @@ dependencies = [
|
|||
"futures",
|
||||
"httpmock",
|
||||
"mongodb",
|
||||
"rig-core",
|
||||
"rig-core 0.11.0",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"testcontainers",
|
||||
|
@ -8698,7 +8776,7 @@ dependencies = [
|
|||
"futures",
|
||||
"httpmock",
|
||||
"neo4rs",
|
||||
"rig-core",
|
||||
"rig-core 0.11.0",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"term_size",
|
||||
|
@ -8718,7 +8796,7 @@ dependencies = [
|
|||
"httpmock",
|
||||
"log",
|
||||
"pgvector",
|
||||
"rig-core",
|
||||
"rig-core 0.11.0",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sqlx",
|
||||
|
@ -8737,7 +8815,7 @@ dependencies = [
|
|||
"anyhow",
|
||||
"httpmock",
|
||||
"qdrant-client",
|
||||
"rig-core",
|
||||
"rig-core 0.11.0",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"testcontainers",
|
||||
|
@ -8752,7 +8830,7 @@ dependencies = [
|
|||
"anyhow",
|
||||
"chrono",
|
||||
"httpmock",
|
||||
"rig-core",
|
||||
"rig-core 0.11.0",
|
||||
"rusqlite",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
@ -8769,7 +8847,7 @@ name = "rig-surrealdb"
|
|||
version = "0.1.3"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"rig-core",
|
||||
"rig-core 0.11.0",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"surrealdb",
|
||||
|
|
|
@ -9,6 +9,8 @@ members = [
|
|||
"rig-qdrant",
|
||||
"rig-core/rig-core-derive",
|
||||
"rig-sqlite",
|
||||
"rig-eternalai", "rig-fastembed",
|
||||
"rig-surrealdb",
|
||||
"rig-eternalai",
|
||||
"rig-fastembed",
|
||||
"rig-bedrock",
|
||||
]
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
[package]
|
||||
name = "rig-bedrock"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
license = "MIT"
|
||||
readme = "README.md"
|
||||
description = "AWS Bedrock model provider for Rig integration."
|
||||
|
||||
[dependencies]
|
||||
rig-core = { version = "0.11.0", features = ["image"] }
|
||||
rig-derive = { path = "../rig-core/rig-core-derive", version = "0.1.0" }
|
||||
serde = { version = "1.0.193", features = ["derive"] }
|
||||
serde_json = "1.0.108"
|
||||
schemars = "0.8.16"
|
||||
tracing = "0.1.40"
|
||||
aws-config = { version = "1.6.0", features = ["behavior-version-latest"] }
|
||||
aws-sdk-bedrockruntime = "1.77.0"
|
||||
aws-smithy-types = "1.3.0"
|
||||
base64 = "0.22.1"
|
||||
async-stream = "0.3.6"
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = "1.0.75"
|
||||
tokio = { version = "1.34.0", features = ["full"] }
|
||||
tracing-subscriber = "0.3.18"
|
||||
reqwest = { version = "0.12.12", features = ["json", "stream"] }
|
|
@ -0,0 +1,23 @@
|
|||
## Rig-Bedrock
|
||||
This companion crate integrates AWS Bedrock as model provider with Rig.
|
||||
|
||||
## Usage
|
||||
|
||||
Add the companion crate to your `Cargo.toml`, along with the rig-core crate:
|
||||
|
||||
```toml
|
||||
[dependencies]
|
||||
rig-bedrock = "0.1.0"
|
||||
rig-core = "0.9.1"
|
||||
```
|
||||
|
||||
You can also run `cargo add rig-bedrock rig-core` to add the most recent versions of the dependencies to your project.
|
||||
|
||||
See the [`/examples`](./examples) folder for usage examples.
|
||||
|
||||
Make sure to have AWS credentials env vars loaded before starting client such as:
|
||||
```shell
|
||||
export AWS_DEFAULT_REGION=us-east-1
|
||||
export AWS_SECRET_ACCESS_KEY=.......
|
||||
export AWS_ACCESS_KEY_ID=......
|
||||
```
|
|
@ -0,0 +1,121 @@
|
|||
use rig::{agent::AgentBuilder, completion::Prompt, loaders::FileLoader};
|
||||
use rig_bedrock::{
|
||||
client::{Client, ClientBuilder},
|
||||
completion::AMAZON_NOVA_LITE,
|
||||
};
|
||||
use tracing::info;
|
||||
|
||||
mod common;
|
||||
|
||||
/// Runs 4 agents based on AWS Bedrock (derived from the agent_with_grok example)
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_max_level(tracing::Level::INFO)
|
||||
.with_target(false)
|
||||
.init();
|
||||
|
||||
info!("Running basic agent");
|
||||
basic().await?;
|
||||
|
||||
info!("\nRunning agent with tools");
|
||||
tools().await?;
|
||||
|
||||
info!("\nRunning agent with loaders");
|
||||
loaders().await?;
|
||||
|
||||
info!("\nRunning agent with context");
|
||||
context().await?;
|
||||
|
||||
info!("\n\nAll agents ran successfully");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn client() -> Client {
|
||||
ClientBuilder::new().build().await
|
||||
}
|
||||
|
||||
async fn partial_agent() -> AgentBuilder<rig_bedrock::completion::CompletionModel> {
|
||||
let client = client().await;
|
||||
client.agent(AMAZON_NOVA_LITE)
|
||||
}
|
||||
|
||||
/// Create an AWS Bedrock agent with a system prompt
|
||||
async fn basic() -> Result<(), anyhow::Error> {
|
||||
let agent = partial_agent()
|
||||
.await
|
||||
.preamble("Answer with json format only")
|
||||
.build();
|
||||
|
||||
let response = agent.prompt("Describe solar system").await?;
|
||||
info!("{}", response);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create an AWS Bedrock with tools
|
||||
async fn tools() -> Result<(), anyhow::Error> {
|
||||
let calculator_agent = partial_agent()
|
||||
.await
|
||||
.preamble("You must only do math by using a tool.")
|
||||
.max_tokens(1024)
|
||||
.tool(common::Adder)
|
||||
.build();
|
||||
|
||||
info!(
|
||||
"Calculator Agent: add 400 and 20\nResult: {}",
|
||||
calculator_agent.prompt("add 400 and 20").await?
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn context() -> Result<(), anyhow::Error> {
|
||||
let model = client().await.completion_model(AMAZON_NOVA_LITE);
|
||||
|
||||
// Create an agent with multiple context documents
|
||||
let agent = AgentBuilder::new(model)
|
||||
.preamble("Answer the question")
|
||||
.context("Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
|
||||
.context("Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
|
||||
.context("Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
|
||||
.build();
|
||||
|
||||
// Prompt the agent and print the response
|
||||
let response = agent.prompt("What does \"glarb-glarb\" mean?").await?;
|
||||
|
||||
info!("What does \"glarb-glarb\" mean?\n{}", response);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Based upon the `loaders` example
|
||||
///
|
||||
/// This example loads in all the rust examples from the rig-core crate and uses them as\\
|
||||
/// context for the agent
|
||||
async fn loaders() -> Result<(), anyhow::Error> {
|
||||
let model = client().await.completion_model(AMAZON_NOVA_LITE);
|
||||
|
||||
// Load in all the rust examples
|
||||
let examples = FileLoader::with_glob("rig-core/examples/*.rs")?
|
||||
.read_with_path()
|
||||
.ignore_errors()
|
||||
.into_iter();
|
||||
|
||||
// Create an agent with multiple context documents
|
||||
let agent = examples
|
||||
.fold(AgentBuilder::new(model), |builder, (path, content)| {
|
||||
builder.context(format!("Rust Example {:?}:\n{}", path, content).as_str())
|
||||
})
|
||||
.preamble("Answer the question")
|
||||
.build();
|
||||
|
||||
// Prompt the agent and print the response
|
||||
let response = agent
|
||||
.prompt("Which rust example is best suited for the operation 1 + 2")
|
||||
.await?;
|
||||
|
||||
info!("{}", response);
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
use std::{
|
||||
error::Error,
|
||||
fmt::{Display, Formatter},
|
||||
};
|
||||
|
||||
use rig::{completion::ToolDefinition, tool::Tool};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct OperationArgs {
|
||||
x: i32,
|
||||
y: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct MathError {}
|
||||
|
||||
impl Display for MathError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Math error")
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for MathError {}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
pub struct Adder;
|
||||
impl Tool for Adder {
|
||||
const NAME: &'static str = "add";
|
||||
|
||||
type Error = MathError;
|
||||
type Args = OperationArgs;
|
||||
type Output = i32;
|
||||
|
||||
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||
ToolDefinition {
|
||||
name: "add".to_string(),
|
||||
description: "Add x and y together".to_string(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {
|
||||
"type": "number",
|
||||
"description": "The first number to add"
|
||||
},
|
||||
"y": {
|
||||
"type": "number",
|
||||
"description": "The second number to add"
|
||||
}
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||
let result = args.x + args.y;
|
||||
Ok(result)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,50 @@
|
|||
use reqwest::Client;
|
||||
|
||||
use rig::{
|
||||
completion::{message::Document, Prompt},
|
||||
message::{ContentFormat, DocumentMediaType},
|
||||
};
|
||||
|
||||
use base64::{prelude::BASE64_STANDARD, Engine};
|
||||
use rig_bedrock::{client::ClientBuilder, completion::AMAZON_NOVA_LITE};
|
||||
use tracing::info;
|
||||
|
||||
const DOCUMENT_URL: &str =
|
||||
"https://www.inf.ed.ac.uk/teaching/courses/ai2/module4/small_slides/small-agents.pdf";
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_max_level(tracing::Level::INFO)
|
||||
.without_time()
|
||||
.with_level(false)
|
||||
.with_target(false)
|
||||
.init();
|
||||
|
||||
let client = ClientBuilder::new().build().await;
|
||||
let agent = client
|
||||
.agent(AMAZON_NOVA_LITE)
|
||||
.preamble("Describe this document but respond with json format only")
|
||||
.temperature(0.5)
|
||||
.build();
|
||||
|
||||
let reqwest_client = Client::new();
|
||||
let response = reqwest_client.get(DOCUMENT_URL).send().await?;
|
||||
|
||||
info!("Status: {}", response.status().as_str());
|
||||
info!("Content Type: {:?}", response.headers().get("Content-Type"));
|
||||
|
||||
let document_bytes = response.bytes().await?;
|
||||
let bytes_base64 = BASE64_STANDARD.encode(document_bytes);
|
||||
|
||||
let document = Document {
|
||||
data: bytes_base64,
|
||||
format: Some(ContentFormat::Base64),
|
||||
media_type: Some(DocumentMediaType::PDF),
|
||||
};
|
||||
|
||||
let response = agent.prompt(document).await?;
|
||||
info!("{}", response);
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
use rig::Embed;
|
||||
use rig_bedrock::{client::ClientBuilder, embedding::AMAZON_TITAN_EMBED_TEXT_V2_0};
|
||||
use tracing::info;
|
||||
|
||||
#[derive(rig_derive::Embed, Debug)]
|
||||
struct Greetings {
|
||||
#[embed]
|
||||
message: String,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_max_level(tracing::Level::INFO)
|
||||
.with_target(false)
|
||||
.init();
|
||||
|
||||
let client = ClientBuilder::new().build().await;
|
||||
|
||||
let embeddings = client
|
||||
.embeddings(AMAZON_TITAN_EMBED_TEXT_V2_0, 256)
|
||||
.document(Greetings {
|
||||
message: "aa".to_string(),
|
||||
})?
|
||||
.document(Greetings {
|
||||
message: "bb".to_string(),
|
||||
})?
|
||||
.build()
|
||||
.await?;
|
||||
|
||||
info!("{:?}", embeddings);
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
use rig_bedrock::{client::ClientBuilder, completion::AMAZON_NOVA_LITE};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::info;
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema, Serialize)]
|
||||
struct Person {
|
||||
pub first_name: Option<String>,
|
||||
pub last_name: Option<String>,
|
||||
pub job: Option<String>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_max_level(tracing::Level::INFO)
|
||||
.with_target(false)
|
||||
.init();
|
||||
|
||||
let client = ClientBuilder::new().build().await;
|
||||
|
||||
let data_extractor = client.extractor::<Person>(AMAZON_NOVA_LITE).build();
|
||||
let person = data_extractor
|
||||
.extract("Hello my name is John Doe! I am a software engineer.")
|
||||
.await?;
|
||||
|
||||
info!(
|
||||
"AWS Bedrock: {}",
|
||||
serde_json::to_string_pretty(&person).unwrap()
|
||||
);
|
||||
Ok(())
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
use rig::image_generation::ImageGenerationModel;
|
||||
use rig_bedrock::client::ClientBuilder;
|
||||
use rig_bedrock::image::AMAZON_NOVA_CANVAS;
|
||||
use std::fs::File;
|
||||
use std::io::Write;
|
||||
use std::path::Path;
|
||||
|
||||
const DEFAULT_PATH: &str = "./output.png";
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let client = ClientBuilder::new().build().await;
|
||||
let image_generation_model = client.image_generation_model(AMAZON_NOVA_CANVAS);
|
||||
let response = image_generation_model
|
||||
.image_generation_request()
|
||||
.prompt("A castle sitting upon a large mountain, overlooking the water.")
|
||||
.width(512)
|
||||
.height(512)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
// save image
|
||||
let mut file = File::create_new(Path::new(&DEFAULT_PATH)).expect("Failed to create file");
|
||||
let _ = file.write(&response.unwrap().image);
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
use reqwest::Client;
|
||||
|
||||
use rig::{
|
||||
completion::{message::Image, Prompt},
|
||||
message::{ContentFormat, ImageMediaType},
|
||||
};
|
||||
|
||||
use base64::{prelude::BASE64_STANDARD, Engine};
|
||||
use rig_bedrock::{client::ClientBuilder, completion::AMAZON_NOVA_LITE};
|
||||
use tracing::info;
|
||||
|
||||
const IMAGE_URL: &str = "https://playgrounds.network/assets/PG-Logo.png";
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_max_level(tracing::Level::INFO)
|
||||
.with_target(false)
|
||||
.init();
|
||||
|
||||
let client = ClientBuilder::new().build().await;
|
||||
let agent = client
|
||||
.agent(AMAZON_NOVA_LITE)
|
||||
.preamble("You are an image describer.")
|
||||
.temperature(0.5)
|
||||
.build();
|
||||
|
||||
// Grab image and convert to base64
|
||||
let reqwest_client = Client::new();
|
||||
let image_bytes = reqwest_client.get(IMAGE_URL).send().await?.bytes().await?;
|
||||
let image_base64 = BASE64_STANDARD.encode(image_bytes);
|
||||
|
||||
// Compose `Image` for prompt
|
||||
let image = Image {
|
||||
data: image_base64,
|
||||
media_type: Some(ImageMediaType::PNG),
|
||||
format: Some(ContentFormat::Base64),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Prompt the agent and print the response
|
||||
let response = agent.prompt(image).await?;
|
||||
info!("{}", response);
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -0,0 +1,85 @@
|
|||
use std::vec;
|
||||
|
||||
use rig::{
|
||||
completion::Prompt, embeddings::EmbeddingsBuilder,
|
||||
vector_store::in_memory_store::InMemoryVectorStore, Embed,
|
||||
};
|
||||
use rig_bedrock::{
|
||||
client::ClientBuilder, completion::AMAZON_NOVA_LITE, embedding::AMAZON_TITAN_EMBED_TEXT_V2_0,
|
||||
};
|
||||
use serde::Serialize;
|
||||
use tracing::info;
|
||||
|
||||
// Data to be RAG-ed.
|
||||
// A vector search needs to be performed on the `definitions` field, so we derive the `Embed` trait for `WordDefinition`
|
||||
// and tag that field with `#[embed]`.
|
||||
#[derive(rig_derive::Embed, Serialize, Clone, Debug, Eq, PartialEq, Default)]
|
||||
struct WordDefinition {
|
||||
id: String,
|
||||
word: String,
|
||||
#[embed]
|
||||
definitions: Vec<String>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_max_level(tracing::Level::INFO)
|
||||
.with_target(false)
|
||||
.init();
|
||||
|
||||
let client = ClientBuilder::new().build().await;
|
||||
let embedding_model = client.embedding_model(AMAZON_TITAN_EMBED_TEXT_V2_0, 256);
|
||||
|
||||
// Generate embeddings for the definitions of all the documents using the specified embedding model.
|
||||
let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
|
||||
.documents(vec![
|
||||
WordDefinition {
|
||||
id: "doc0".to_string(),
|
||||
word: "flurbo".to_string(),
|
||||
definitions: vec![
|
||||
"1. *flurbo* (name): A flurbo is a green alien that lives on cold planets.".to_string(),
|
||||
"2. *flurbo* (name): A fictional digital currency that originated in the animated series Rick and Morty.".to_string()
|
||||
]
|
||||
},
|
||||
WordDefinition {
|
||||
id: "doc1".to_string(),
|
||||
word: "glarb-glarb".to_string(),
|
||||
definitions: vec![
|
||||
"1. *glarb-glarb* (noun): A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
|
||||
"2. *glarb-glarb* (noun): A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
|
||||
]
|
||||
},
|
||||
WordDefinition {
|
||||
id: "doc2".to_string(),
|
||||
word: "linglingdong".to_string(),
|
||||
definitions: vec![
|
||||
"1. *linglingdong* (noun): A term used by inhabitants of the far side of the moon to describe humans.".to_string(),
|
||||
"2. *linglingdong* (noun): A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string()
|
||||
]
|
||||
},
|
||||
])?
|
||||
.build()
|
||||
.await?;
|
||||
|
||||
// Create vector store with the embeddings
|
||||
let vector_store = InMemoryVectorStore::from_documents(embeddings);
|
||||
|
||||
// Create vector store index
|
||||
let index = vector_store.index(embedding_model);
|
||||
|
||||
let rag_agent = client.agent(AMAZON_NOVA_LITE)
|
||||
.preamble("
|
||||
You are a dictionary assistant here to assist the user in understanding the meaning of words.
|
||||
You will find additional non-standard word definitions that could be useful below.
|
||||
")
|
||||
.dynamic_context(1, index)
|
||||
.build();
|
||||
|
||||
// Prompt the agent and print the response
|
||||
let response = rag_agent.prompt("What does \"glarb-glarb\" mean?").await?;
|
||||
|
||||
info!("{}", response);
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
use rig::streaming::{stream_to_stdout, StreamingPrompt};
|
||||
use rig_bedrock::{client::ClientBuilder, completion::AMAZON_NOVA_LITE};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
// Create streaming agent with a single context prompt
|
||||
let agent = ClientBuilder::new()
|
||||
.build()
|
||||
.await
|
||||
.agent(AMAZON_NOVA_LITE)
|
||||
.preamble("Be precise and concise.")
|
||||
.temperature(0.5)
|
||||
.build();
|
||||
|
||||
// Stream the response and print chunks as they arrive
|
||||
let mut stream = agent
|
||||
.stream_prompt("When and where and what type is the next solar eclipse?")
|
||||
.await?;
|
||||
|
||||
stream_to_stdout(agent, &mut stream).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
use rig::streaming::{stream_to_stdout, StreamingPrompt};
|
||||
use rig_bedrock::{client::ClientBuilder, completion::AMAZON_NOVA_LITE};
|
||||
mod common;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
tracing_subscriber::fmt().init();
|
||||
// Create agent with a single context prompt and two tools
|
||||
let agent = ClientBuilder::new()
|
||||
.build()
|
||||
.await
|
||||
.agent(AMAZON_NOVA_LITE)
|
||||
.preamble(
|
||||
"You are a calculator here to help the user perform arithmetic
|
||||
operations. Use the tools provided to answer the user's question.
|
||||
make your answer long, so we can test the streaming functionality,
|
||||
like 20 words",
|
||||
)
|
||||
.max_tokens(1024)
|
||||
.tool(common::Adder)
|
||||
.build();
|
||||
|
||||
println!("Calculate 2 + 5");
|
||||
let mut stream = agent.stream_prompt("Calculate 2 + 5").await?;
|
||||
stream_to_stdout(agent, &mut stream).await?;
|
||||
Ok(())
|
||||
}
|
|
@ -0,0 +1,86 @@
|
|||
use crate::image::ImageGenerationModel;
|
||||
use crate::{completion::CompletionModel, embedding::EmbeddingModel};
|
||||
use aws_config::{BehaviorVersion, Region};
|
||||
use rig::{agent::AgentBuilder, embeddings, extractor::ExtractorBuilder, Embed};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub const DEFAULT_AWS_REGION: &str = "us-east-1";
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ClientBuilder<'a> {
|
||||
region: &'a str,
|
||||
}
|
||||
|
||||
/// Create a new Bedrock client using the builder <br>
|
||||
impl<'a> ClientBuilder<'a> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
region: DEFAULT_AWS_REGION,
|
||||
}
|
||||
}
|
||||
|
||||
/// Make sure to verify model and region [compatibility]
|
||||
///
|
||||
/// [compatibility]: https://docs.aws.amazon.com/bedrock/latest/userguide/models-regions.html
|
||||
pub fn region(mut self, region: &'a str) -> Self {
|
||||
self.region = region;
|
||||
self
|
||||
}
|
||||
|
||||
/// Make sure you have permissions to access [Amazon Bedrock foundation model]
|
||||
///
|
||||
/// [ Amazon Bedrock foundation model]: <https://docs.aws.amazon.com/bedrock/latest/userguide/model-access-modify.html>
|
||||
pub async fn build(self) -> Client {
|
||||
let sdk_config = aws_config::defaults(BehaviorVersion::latest())
|
||||
.region(Region::new(String::from(self.region)))
|
||||
.load()
|
||||
.await;
|
||||
let client = aws_sdk_bedrockruntime::Client::new(&sdk_config);
|
||||
Client { aws_client: client }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ClientBuilder<'_> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Client {
|
||||
pub(crate) aws_client: aws_sdk_bedrockruntime::Client,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
pub fn completion_model(&self, model: &str) -> CompletionModel {
|
||||
CompletionModel::new(self.clone(), model)
|
||||
}
|
||||
|
||||
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
|
||||
AgentBuilder::new(self.completion_model(model))
|
||||
}
|
||||
|
||||
pub fn embedding_model(&self, model: &str, ndims: usize) -> EmbeddingModel {
|
||||
EmbeddingModel::new(self.clone(), model, Some(ndims))
|
||||
}
|
||||
|
||||
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
|
||||
&self,
|
||||
model: &str,
|
||||
) -> ExtractorBuilder<T, CompletionModel> {
|
||||
ExtractorBuilder::new(self.completion_model(model))
|
||||
}
|
||||
|
||||
pub fn embeddings<D: Embed>(
|
||||
&self,
|
||||
model: &str,
|
||||
ndims: usize,
|
||||
) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D> {
|
||||
embeddings::EmbeddingsBuilder::new(self.embedding_model(model, ndims))
|
||||
}
|
||||
|
||||
pub fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
|
||||
ImageGenerationModel::new(self.clone(), model)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,138 @@
|
|||
//! All supported models <https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html>
|
||||
use crate::{
|
||||
client::Client,
|
||||
types::{
|
||||
assistant_content::AwsConverseOutput, completion_request::AwsCompletionRequest,
|
||||
errors::AwsSdkConverseError,
|
||||
},
|
||||
};
|
||||
|
||||
use rig::completion::{self, CompletionError};
|
||||
|
||||
/// `amazon.nova-canvas-v1:0`
|
||||
pub const AMAZON_NOVA_CANVAS: &str = "amazon.nova-canvas-v1:0";
|
||||
/// `amazon.nova-lite-v1:0`
|
||||
pub const AMAZON_NOVA_LITE: &str = "amazon.nova-lite-v1:0";
|
||||
/// `amazon.nova-micro-v1:0`
|
||||
pub const AMAZON_NOVA_MICRO: &str = "amazon.nova-micro-v1:0";
|
||||
/// `amazon.nova-pro-v1:0`
|
||||
pub const AMAZON_NOVA_PRO: &str = "amazon.nova-pro-v1:0";
|
||||
/// `amazon.rerank-v1:0`
|
||||
pub const AMAZON_RERANK_1_0: &str = "amazon.rerank-v1:0";
|
||||
/// `amazon.titan-text-express-v1`
|
||||
pub const AMAZON_TITAN_TEXT_EXPRESS_V1: &str = "amazon.titan-text-express-v1";
|
||||
/// `amazon.titan-text-lite-v1`
|
||||
pub const AMAZON_TITAN_TEXT_LITE_V1: &str = "amazon.titan-text-lite-v1";
|
||||
/// `amazon.titan-text-premier-v1:0`
|
||||
pub const AMAZON_TITAN_TEXT_PREMIER_V1_0: &str = "amazon.titan-text-premier-v1:0";
|
||||
|
||||
/// `anthropic.claude-3-haiku-20240307-v1:0`
|
||||
pub const ANTHROPIC_CLAUDE_3_HAIKU: &str = "anthropic.claude-3-haiku-20240307-v1:0";
|
||||
/// `anthropic.claude-3-opus-20240229-v1:0`
|
||||
pub const ANTHROPIC_CLAUDE_3_OPUS: &str = "anthropic.claude-3-opus-20240229-v1:0";
|
||||
/// `anthropic.claude-3-sonnet-20240229-v1:0`
|
||||
pub const ANTHROPIC_CLAUDE_3_SONNET: &str = "anthropic.claude-3-sonnet-20240229-v1:0";
|
||||
/// `anthropic.claude-3-5-haiku-20241022-v1:0`
|
||||
pub const ANTHROPIC_CLAUDE_3_5_HAIKU: &str = "anthropic.claude-3-5-haiku-20241022-v1:0";
|
||||
/// `anthropic.claude-3-5-sonnet-20241022-v2:0`
|
||||
pub const ANTHROPIC_CLAUDE_3_5_SONNET_V2: &str = "anthropic.claude-3-5-sonnet-20241022-v2:0";
|
||||
/// `anthropic.claude-3-5-sonnet-20240620-v1:0`
|
||||
pub const ANTHROPIC_CLAUDE_3_5_SONNET: &str = "anthropic.claude-3-5-sonnet-20240620-v1:0";
|
||||
/// `anthropic.claude-3-7-sonnet-20250219-v1:0`
|
||||
pub const ANTHROPIC_CLAUDE_3_7_SONNET: &str = "anthropic.claude-3-7-sonnet-20250219-v1:0";
|
||||
/// `cohere.command-light-text-v14`
|
||||
pub const COHERE_COMMAND_LIGHT_TEXT: &str = "cohere.command-light-text-v14";
|
||||
/// `cohere.command-r-plus-v1:0`
|
||||
pub const COHERE_COMMAND_R_PLUS: &str = "cohere.command-r-plus-v1:0";
|
||||
/// `cohere.command-r-v1:0`
|
||||
pub const COHERE_COMMAND_R: &str = "cohere.command-r-v1:0";
|
||||
/// `cohere.command-text-v14`
|
||||
pub const COHERE_COMMAND: &str = "cohere.command-text-v14";
|
||||
/// `cohere.rerank-v3-5:0`
|
||||
pub const COHERE_RERANK_V3_5: &str = "cohere.rerank-v3-5:0";
|
||||
/// `luma.ray-v2:0`
|
||||
pub const LUMA_RAY_V2_0: &str = "luma.ray-v2:0";
|
||||
/// `meta.llama3-8b-instruct-v1:0`
|
||||
pub const LLAMA_3_8B_INSTRUCT: &str = "meta.llama3-8b-instruct-v1:0";
|
||||
/// `meta.llama3-70b-instruct-v1:0`
|
||||
pub const LLAMA_3_70B_INSTRUCT: &str = "meta.llama3-70b-instruct-v1:0";
|
||||
/// `meta.llama3-1-8b-instruct-v1:0`
|
||||
pub const LLAMA_3_1_8B_INSTRUCT: &str = "meta.llama3-1-8b-instruct-v1:0";
|
||||
/// `meta.llama3-1-70b-instruct-v1:0`
|
||||
pub const LLAMA_3_1_70B_INSTRUCT: &str = "meta.llama3-1-70b-instruct-v1:0";
|
||||
/// `meta.llama3-1-405b-instruct-v1:0`
|
||||
pub const LLAMA_3_1_405B_INSTRUCT: &str = "meta.llama3-1-405b-instruct-v1:0";
|
||||
/// `meta.llama3-2-1b-instruct-v1:0`
|
||||
pub const LLAMA_3_2_1B_INSTRUCT: &str = "meta.llama3-2-1b-instruct-v1:0";
|
||||
/// `meta.llama3-2-3b-instruct-v1:0`
|
||||
pub const LLAMA_3_2_3B_INSTRUCT: &str = "meta.llama3-2-3b-instruct-v1:0";
|
||||
/// `meta.llama3-2-11b-instruct-v1:0`
|
||||
pub const LLAMA_3_2_11B_INSTRUCT: &str = "meta.llama3-2-11b-instruct-v1:0";
|
||||
/// `meta.llama3-2-90b-instruct-v1:0`
|
||||
pub const LLAMA_3_2_90B_INSTRUCT: &str = "meta.llama3-2-90b-instruct-v1:0";
|
||||
/// `meta.llama3-3-70b-instruct-v1:0`
|
||||
pub const LLAMA_3_2_70B_INSTRUCT: &str = "meta.llama3-3-70b-instruct-v1:0";
|
||||
/// `mistral.mistral-7b-instruct-v0:2`
|
||||
pub const MISTRAL_7B_INSTRUCT: &str = "mistral.mistral-7b-instruct-v0:2";
|
||||
/// `mistral.mistral-large-2402-v1:0`
|
||||
pub const MISTRAL_LARGE_24_02: &str = "mistral.mistral-large-2402-v1:0";
|
||||
/// `mistral.mistral-large-2407-v1:0`
|
||||
pub const MISTRAL_LARGE_24_07: &str = "mistral.mistral-large-2407-v1:0";
|
||||
/// `mistral.mistral-small-2402-v1:0`
|
||||
pub const MISTRAL_SMALL_24_02: &str = "mistral.mistral-small-2402-v1:0";
|
||||
/// `mistral.mixtral-8x7b-instruct-v0:1`
|
||||
pub const MISTRAL_MIXTRAL_8X7B_INSTRUCT_V0: &str = "mistral.mixtral-8x7b-instruct-v0:1";
|
||||
/// `stability.sd3-5-large-v1:0`
|
||||
pub const STABILITY_SD3_5_LARGE: &str = "stability.sd3-5-large-v1:0";
|
||||
/// `ai21.jamba-1-5-large-v1:0`
|
||||
pub const JAMBA_1_5_LARGE: &str = "ai21.jamba-1-5-large-v1:0";
|
||||
/// `ai21.jamba-1-5-mini-v1:0`
|
||||
pub const JAMBA_1_5_MINI: &str = "ai21.jamba-1-5-mini-v1:0";
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CompletionModel {
|
||||
pub(crate) client: Client,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
impl CompletionModel {
|
||||
pub fn new(client: Client, model: &str) -> Self {
|
||||
Self {
|
||||
client,
|
||||
model: model.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl completion::CompletionModel for CompletionModel {
|
||||
type Response = AwsConverseOutput;
|
||||
|
||||
async fn completion(
|
||||
&self,
|
||||
completion_request: completion::CompletionRequest,
|
||||
) -> Result<completion::CompletionResponse<AwsConverseOutput>, CompletionError> {
|
||||
let request = AwsCompletionRequest(completion_request);
|
||||
|
||||
let mut converse_builder = self
|
||||
.client
|
||||
.aws_client
|
||||
.converse()
|
||||
.model_id(self.model.as_str());
|
||||
|
||||
let tool_config = request.tools_config()?;
|
||||
let prompt_with_history = request.prompt_with_history()?;
|
||||
converse_builder = converse_builder
|
||||
.set_additional_model_request_fields(request.additional_params())
|
||||
.set_inference_config(request.inference_config())
|
||||
.set_tool_config(tool_config)
|
||||
.set_system(request.system_prompt())
|
||||
.set_messages(Some(prompt_with_history));
|
||||
|
||||
let response = converse_builder
|
||||
.send()
|
||||
.await
|
||||
.map_err(|sdk_error| Into::<CompletionError>::into(AwsSdkConverseError(sdk_error)))?;
|
||||
|
||||
AwsConverseOutput(response).try_into()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,121 @@
|
|||
use aws_smithy_types::Blob;
|
||||
use rig::embeddings::{self, Embedding, EmbeddingError};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{client::Client, types::errors::AwsSdkInvokeModelError};
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct EmbeddingRequest {
|
||||
pub input_text: String,
|
||||
pub dimensions: usize,
|
||||
pub normalize: bool,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct EmbeddingResponse {
|
||||
pub embedding: Vec<f64>,
|
||||
pub input_text_token_count: usize,
|
||||
}
|
||||
|
||||
/// `amazon.titan-embed-text-v1`
|
||||
pub const AMAZON_TITAN_EMBED_TEXT_V1: &str = "amazon.titan-embed-text-v1";
|
||||
/// `amazon.titan-embed-text-v2:0`
|
||||
pub const AMAZON_TITAN_EMBED_TEXT_V2_0: &str = "amazon.titan-embed-text-v2:0";
|
||||
/// `amazon.titan-embed-image-v1`
|
||||
pub const AMAZON_TITAN_EMBED_IMAGE_V1: &str = "amazon.titan-embed-image-v1";
|
||||
/// `cohere.embed-english-v3`
|
||||
pub const COHERE_EMBED_ENGLISH_V3: &str = "cohere.embed-english-v3";
|
||||
/// `cohere.embed-multilingual-v3`
|
||||
pub const COHERE_EMBED_MULTILINGUAL_V3: &str = "cohere.embed-multilingual-v3";
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EmbeddingModel {
|
||||
client: Client,
|
||||
model: String,
|
||||
ndims: Option<usize>,
|
||||
}
|
||||
|
||||
impl EmbeddingModel {
|
||||
pub fn new(client: Client, model: &str, ndims: Option<usize>) -> Self {
|
||||
Self {
|
||||
client,
|
||||
model: model.to_string(),
|
||||
ndims,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn document_to_embeddings(
|
||||
&self,
|
||||
request: EmbeddingRequest,
|
||||
) -> Result<EmbeddingResponse, EmbeddingError> {
|
||||
let input_document = serde_json::to_string(&request).map_err(EmbeddingError::JsonError)?;
|
||||
|
||||
let model_response = self
|
||||
.client
|
||||
.aws_client
|
||||
.invoke_model()
|
||||
.model_id(self.model.as_str())
|
||||
.content_type("application/json")
|
||||
.accept("application/json")
|
||||
.body(Blob::new(input_document))
|
||||
.send()
|
||||
.await;
|
||||
|
||||
let response = model_response
|
||||
.map_err(|sdk_error| AwsSdkInvokeModelError(sdk_error).into())
|
||||
.map_err(|e: EmbeddingError| e)?;
|
||||
|
||||
let response_str = String::from_utf8(response.body.into_inner())
|
||||
.map_err(|e| EmbeddingError::ResponseError(e.to_string()))?;
|
||||
|
||||
let result: EmbeddingResponse =
|
||||
serde_json::from_str(&response_str).map_err(EmbeddingError::JsonError)?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
impl embeddings::EmbeddingModel for EmbeddingModel {
|
||||
const MAX_DOCUMENTS: usize = 1024;
|
||||
|
||||
fn ndims(&self) -> usize {
|
||||
self.ndims.unwrap_or(0)
|
||||
}
|
||||
|
||||
async fn embed_texts(
|
||||
&self,
|
||||
documents: impl IntoIterator<Item = String> + Send,
|
||||
) -> Result<Vec<Embedding>, EmbeddingError> {
|
||||
let documents: Vec<_> = documents.into_iter().collect();
|
||||
|
||||
let mut results = Vec::new();
|
||||
let mut errors = Vec::new();
|
||||
|
||||
let mut iterator = documents.into_iter();
|
||||
while let Some(embedding) = iterator.next().map(|doc| async move {
|
||||
let request = EmbeddingRequest {
|
||||
input_text: doc.to_owned(),
|
||||
dimensions: self.ndims(),
|
||||
normalize: true,
|
||||
};
|
||||
self.document_to_embeddings(request)
|
||||
.await
|
||||
.map(|embeddings| Embedding {
|
||||
document: doc.to_owned(),
|
||||
vec: embeddings.embedding,
|
||||
})
|
||||
}) {
|
||||
match embedding.await {
|
||||
Ok(embedding) => results.push(embedding),
|
||||
Err(err) => errors.push(err),
|
||||
}
|
||||
}
|
||||
|
||||
match errors.as_slice() {
|
||||
[] => Ok(results),
|
||||
[err, ..] => Err(EmbeddingError::ResponseError(err.to_string())),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
use crate::client::Client;
|
||||
use crate::types::errors::AwsSdkInvokeModelError;
|
||||
use crate::types::text_to_image::{TextToImageGeneration, TextToImageResponse};
|
||||
use aws_smithy_types::Blob;
|
||||
use rig::image_generation::{
|
||||
self, ImageGenerationError, ImageGenerationRequest, ImageGenerationResponse,
|
||||
};
|
||||
|
||||
/// `amazon.titan-image-generator-v1`
|
||||
pub const AMAZON_TITAN_IMAGE_GENERATOR_V1: &str = "amazon.titan-image-generator-v1";
|
||||
/// `amazon.titan-image-generator-v2:0`
|
||||
pub const AMAZON_TITAN_IMAGE_GENERATOR_V2_0: &str = "amazon.titan-image-generator-v2:0";
|
||||
/// `amazon.nova-canvas-v1:0`
|
||||
pub const AMAZON_NOVA_CANVAS: &str = "amazon.nova-canvas-v1:0";
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ImageGenerationModel {
|
||||
pub(crate) client: Client,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
impl ImageGenerationModel {
|
||||
pub fn new(client: Client, model: &str) -> Self {
|
||||
Self {
|
||||
client,
|
||||
model: model.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl image_generation::ImageGenerationModel for ImageGenerationModel {
|
||||
type Response = TextToImageResponse;
|
||||
|
||||
async fn image_generation(
|
||||
&self,
|
||||
generation_request: ImageGenerationRequest,
|
||||
) -> Result<ImageGenerationResponse<Self::Response>, ImageGenerationError> {
|
||||
let mut request = TextToImageGeneration::new(generation_request.prompt);
|
||||
request.width(generation_request.width);
|
||||
request.height(generation_request.height);
|
||||
|
||||
let body = serde_json::to_string(&request)?;
|
||||
let model_response = self
|
||||
.client
|
||||
.aws_client
|
||||
.invoke_model()
|
||||
.model_id(self.model.as_str())
|
||||
.content_type("application/json")
|
||||
.accept("application/json")
|
||||
.body(Blob::new(body))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|sdk_error| {
|
||||
Into::<ImageGenerationError>::into(AwsSdkInvokeModelError(sdk_error))
|
||||
})?;
|
||||
|
||||
let response_str = String::from_utf8(model_response.body.into_inner())
|
||||
.map_err(|e| ImageGenerationError::ResponseError(e.to_string()))?;
|
||||
|
||||
let result: TextToImageResponse = serde_json::from_str(&response_str)
|
||||
.map_err(|e| ImageGenerationError::ResponseError(e.to_string()))?;
|
||||
|
||||
result.try_into()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
pub mod client;
|
||||
pub mod completion;
|
||||
pub mod embedding;
|
||||
pub mod image;
|
||||
pub mod streaming;
|
||||
pub mod types;
|
|
@ -0,0 +1,101 @@
|
|||
use crate::types::completion_request::AwsCompletionRequest;
|
||||
use crate::{completion::CompletionModel, types::errors::AwsSdkConverseStreamError};
|
||||
use async_stream::stream;
|
||||
use aws_sdk_bedrockruntime::types as aws_bedrock;
|
||||
use rig::{
|
||||
completion::CompletionError,
|
||||
streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult},
|
||||
};
|
||||
|
||||
#[derive(Default)]
|
||||
struct ToolCallState {
|
||||
name: String,
|
||||
id: String,
|
||||
input_json: String,
|
||||
}
|
||||
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
async fn stream(
|
||||
&self,
|
||||
completion_request: rig::completion::CompletionRequest,
|
||||
) -> Result<StreamingResult, CompletionError> {
|
||||
let request = AwsCompletionRequest(completion_request);
|
||||
|
||||
let mut converse_builder = self
|
||||
.client
|
||||
.aws_client
|
||||
.converse_stream()
|
||||
.model_id(self.model.as_str());
|
||||
|
||||
let tool_config = request.tools_config()?;
|
||||
let prompt_with_history = request.prompt_with_history()?;
|
||||
converse_builder = converse_builder
|
||||
.set_additional_model_request_fields(request.additional_params())
|
||||
.set_inference_config(request.inference_config())
|
||||
.set_tool_config(tool_config)
|
||||
.set_system(request.system_prompt())
|
||||
.set_messages(Some(prompt_with_history));
|
||||
|
||||
let response = converse_builder.send().await.map_err(|sdk_error| {
|
||||
Into::<CompletionError>::into(AwsSdkConverseStreamError(sdk_error))
|
||||
})?;
|
||||
|
||||
Ok(Box::pin(stream! {
|
||||
let mut current_tool_call: Option<ToolCallState> = None;
|
||||
let mut stream = response.stream;
|
||||
while let Ok(Some(output)) = stream.recv().await {
|
||||
match output {
|
||||
aws_bedrock::ConverseStreamOutput::ContentBlockDelta(event) => {
|
||||
let delta = event.delta.ok_or(CompletionError::ProviderError("The delta for a content block is missing".into()))?;
|
||||
match delta {
|
||||
aws_bedrock::ContentBlockDelta::Text(text) => {
|
||||
if current_tool_call.is_none() {
|
||||
yield Ok(StreamingChoice::Message(text))
|
||||
}
|
||||
},
|
||||
aws_bedrock::ContentBlockDelta::ToolUse(tool) => {
|
||||
if let Some(ref mut tool_call) = current_tool_call {
|
||||
tool_call.input_json.push_str(tool.input());
|
||||
}
|
||||
},
|
||||
_ => {}
|
||||
}
|
||||
},
|
||||
aws_bedrock::ConverseStreamOutput::ContentBlockStart(event) => {
|
||||
match event.start.ok_or(CompletionError::ProviderError("ContentBlockStart has no data".into()))? {
|
||||
aws_bedrock::ContentBlockStart::ToolUse(tool_use) => {
|
||||
current_tool_call = Some(ToolCallState {
|
||||
name: tool_use.name,
|
||||
id: tool_use.tool_use_id,
|
||||
input_json: String::new(),
|
||||
});
|
||||
},
|
||||
_ => yield Err(CompletionError::ProviderError("Stream is empty".into()))
|
||||
}
|
||||
},
|
||||
aws_bedrock::ConverseStreamOutput::MessageStop(message_stop_event) => {
|
||||
match message_stop_event.stop_reason {
|
||||
aws_bedrock::StopReason::ToolUse => {
|
||||
if let Some(tool_call) = current_tool_call.take() {
|
||||
let tool_input = serde_json::from_str(tool_call.input_json.as_str())?;
|
||||
yield Ok(StreamingChoice::ToolCall(
|
||||
tool_call.name,
|
||||
tool_call.id,
|
||||
tool_input
|
||||
));
|
||||
} else {
|
||||
yield Err(CompletionError::ProviderError("Failed to call tool".into()))
|
||||
}
|
||||
}
|
||||
aws_bedrock::StopReason::MaxTokens => {
|
||||
yield Err(CompletionError::ProviderError("Exceeded max tokens".into()))
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
},
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,155 @@
|
|||
use aws_sdk_bedrockruntime::operation::converse::ConverseOutput;
|
||||
use aws_sdk_bedrockruntime::types as aws_bedrock;
|
||||
|
||||
use rig::{
|
||||
completion::CompletionError,
|
||||
message::{AssistantContent, Text, ToolCall, ToolFunction},
|
||||
OneOrMany,
|
||||
};
|
||||
|
||||
use crate::types::message::RigMessage;
|
||||
|
||||
use super::json::AwsDocument;
|
||||
use rig::completion;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AwsConverseOutput(pub ConverseOutput);
|
||||
|
||||
impl TryFrom<AwsConverseOutput> for completion::CompletionResponse<AwsConverseOutput> {
|
||||
type Error = CompletionError;
|
||||
|
||||
fn try_from(value: AwsConverseOutput) -> Result<Self, Self::Error> {
|
||||
let message: RigMessage = value
|
||||
.to_owned()
|
||||
.0
|
||||
.output
|
||||
.ok_or(CompletionError::ProviderError(
|
||||
"Model didn't return any output".into(),
|
||||
))?
|
||||
.as_message()
|
||||
.map_err(|_| {
|
||||
CompletionError::ProviderError(
|
||||
"Failed to extract message from converse output".into(),
|
||||
)
|
||||
})?
|
||||
.to_owned()
|
||||
.try_into()?;
|
||||
|
||||
let choice = match message.0 {
|
||||
completion::Message::Assistant { content } => Ok(content),
|
||||
_ => Err(CompletionError::ResponseError(
|
||||
"Response contained no message or tool call (empty)".to_owned(),
|
||||
)),
|
||||
}?;
|
||||
|
||||
if let Some(tool_use) = choice.iter().find_map(|content| match content {
|
||||
AssistantContent::ToolCall(tool_call) => Some(tool_call.to_owned()),
|
||||
_ => None,
|
||||
}) {
|
||||
return Ok(completion::CompletionResponse {
|
||||
choice: OneOrMany::one(AssistantContent::ToolCall(ToolCall {
|
||||
id: tool_use.id,
|
||||
function: ToolFunction {
|
||||
name: tool_use.function.name,
|
||||
arguments: tool_use.function.arguments,
|
||||
},
|
||||
})),
|
||||
raw_response: value,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(completion::CompletionResponse {
|
||||
choice,
|
||||
raw_response: value,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RigAssistantContent(pub AssistantContent);
|
||||
|
||||
impl TryFrom<aws_bedrock::ContentBlock> for RigAssistantContent {
|
||||
type Error = CompletionError;
|
||||
|
||||
fn try_from(value: aws_bedrock::ContentBlock) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
aws_bedrock::ContentBlock::Text(text) => {
|
||||
Ok(RigAssistantContent(AssistantContent::Text(Text { text })))
|
||||
}
|
||||
aws_bedrock::ContentBlock::ToolUse(call) => Ok(RigAssistantContent(
|
||||
completion::AssistantContent::tool_call(
|
||||
&call.tool_use_id,
|
||||
&call.name,
|
||||
AwsDocument(call.input).into(),
|
||||
),
|
||||
)),
|
||||
_ => Err(CompletionError::ProviderError(
|
||||
"AWS Bedrock returned unsupported ContentBlock".into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<RigAssistantContent> for aws_bedrock::ContentBlock {
|
||||
type Error = CompletionError;
|
||||
|
||||
fn try_from(value: RigAssistantContent) -> Result<Self, Self::Error> {
|
||||
match value.0 {
|
||||
AssistantContent::Text(text) => Ok(aws_bedrock::ContentBlock::Text(text.text)),
|
||||
AssistantContent::ToolCall(tool_call) => {
|
||||
let doc: AwsDocument = tool_call.function.arguments.into();
|
||||
Ok(aws_bedrock::ContentBlock::ToolUse(
|
||||
aws_bedrock::ToolUseBlock::builder()
|
||||
.tool_use_id(tool_call.id)
|
||||
.name(tool_call.function.name)
|
||||
.input(doc.0)
|
||||
.build()
|
||||
.map_err(|e| CompletionError::ProviderError(e.to_string()))?,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::types::assistant_content::RigAssistantContent;
|
||||
|
||||
use super::AwsConverseOutput;
|
||||
use aws_sdk_bedrockruntime::types as aws_bedrock;
|
||||
use rig::{completion, message::AssistantContent, OneOrMany};
|
||||
|
||||
#[test]
|
||||
fn aws_converse_output_to_completion_response() {
|
||||
let message = aws_bedrock::Message::builder()
|
||||
.role(aws_bedrock::ConversationRole::Assistant)
|
||||
.content(aws_bedrock::ContentBlock::Text("txt".into()))
|
||||
.build()
|
||||
.unwrap();
|
||||
let output = aws_bedrock::ConverseOutput::Message(message);
|
||||
let converse_output =
|
||||
aws_sdk_bedrockruntime::operation::converse::ConverseOutput::builder()
|
||||
.output(output)
|
||||
.stop_reason(aws_bedrock::StopReason::EndTurn)
|
||||
.build()
|
||||
.unwrap();
|
||||
let completion: Result<completion::CompletionResponse<AwsConverseOutput>, _> =
|
||||
AwsConverseOutput(converse_output).try_into();
|
||||
assert_eq!(completion.is_ok(), true);
|
||||
let completion = completion.unwrap();
|
||||
assert_eq!(
|
||||
completion.choice,
|
||||
OneOrMany::one(AssistantContent::Text("txt".into()))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn aws_content_block_to_assistant_content() {
|
||||
let content_block = aws_bedrock::ContentBlock::Text("text".into());
|
||||
let rig_assistant_content: Result<RigAssistantContent, _> = content_block.try_into();
|
||||
assert_eq!(rig_assistant_content.is_ok(), true);
|
||||
assert_eq!(
|
||||
rig_assistant_content.unwrap().0,
|
||||
AssistantContent::Text("text".into())
|
||||
);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,85 @@
|
|||
use crate::types::json::AwsDocument;
|
||||
use crate::types::message::RigMessage;
|
||||
use aws_sdk_bedrockruntime::types as aws_bedrock;
|
||||
use aws_sdk_bedrockruntime::types::{
|
||||
InferenceConfiguration, SystemContentBlock, Tool, ToolConfiguration, ToolInputSchema,
|
||||
ToolSpecification,
|
||||
};
|
||||
use rig::completion::{CompletionError, Message};
|
||||
|
||||
pub struct AwsCompletionRequest(pub rig::completion::CompletionRequest);
|
||||
|
||||
impl AwsCompletionRequest {
|
||||
pub fn additional_params(&self) -> Option<aws_smithy_types::Document> {
|
||||
self.0
|
||||
.additional_params
|
||||
.to_owned()
|
||||
.map(|params| params.into())
|
||||
.map(|doc: AwsDocument| doc.0)
|
||||
}
|
||||
|
||||
pub fn inference_config(&self) -> Option<InferenceConfiguration> {
|
||||
let mut inference_configuration = InferenceConfiguration::builder();
|
||||
|
||||
if let Some(temperature) = &self.0.temperature {
|
||||
inference_configuration =
|
||||
inference_configuration.set_temperature(Some(*temperature as f32));
|
||||
}
|
||||
|
||||
if let Some(max_tokens) = &self.0.max_tokens {
|
||||
inference_configuration =
|
||||
inference_configuration.set_max_tokens(Some(*max_tokens as i32));
|
||||
}
|
||||
|
||||
Some(inference_configuration.build())
|
||||
}
|
||||
|
||||
pub fn tools_config(&self) -> Result<Option<ToolConfiguration>, CompletionError> {
|
||||
let mut tools = vec![];
|
||||
for tool_definition in self.0.tools.iter() {
|
||||
let doc: AwsDocument = tool_definition.parameters.clone().into();
|
||||
let schema = ToolInputSchema::Json(doc.0);
|
||||
let tool = Tool::ToolSpec(
|
||||
ToolSpecification::builder()
|
||||
.name(tool_definition.name.clone())
|
||||
.set_description(Some(tool_definition.description.clone()))
|
||||
.set_input_schema(Some(schema))
|
||||
.build()
|
||||
.map_err(|e| CompletionError::RequestError(e.into()))?,
|
||||
);
|
||||
tools.push(tool);
|
||||
}
|
||||
|
||||
if !tools.is_empty() {
|
||||
let config = ToolConfiguration::builder()
|
||||
.set_tools(Some(tools))
|
||||
.build()
|
||||
.map_err(|e| CompletionError::RequestError(e.into()))?;
|
||||
|
||||
Ok(Some(config))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn system_prompt(&self) -> Option<Vec<SystemContentBlock>> {
|
||||
self.0
|
||||
.preamble
|
||||
.to_owned()
|
||||
.map(|system_prompt| vec![SystemContentBlock::Text(system_prompt)])
|
||||
}
|
||||
|
||||
pub fn prompt_with_history(&self) -> Result<Vec<aws_bedrock::Message>, CompletionError> {
|
||||
let mut chat_history = self.0.chat_history.to_owned();
|
||||
let prompt_with_context = self.0.prompt_with_context();
|
||||
|
||||
let mut full_history: Vec<Message> = Vec::new();
|
||||
full_history.append(&mut chat_history);
|
||||
full_history.push(prompt_with_context);
|
||||
|
||||
full_history
|
||||
.into_iter()
|
||||
.map(|message| RigMessage(message).try_into())
|
||||
.collect::<Result<Vec<aws_bedrock::Message>, _>>()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,153 @@
|
|||
use aws_sdk_bedrockruntime::types as aws_bedrock;
|
||||
|
||||
use rig::{
|
||||
completion::CompletionError,
|
||||
message::{ContentFormat, Document},
|
||||
};
|
||||
|
||||
pub(crate) use crate::types::media_types::RigDocumentMediaType;
|
||||
use base64::{prelude::BASE64_STANDARD, Engine};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RigDocument(pub Document);
|
||||
|
||||
impl TryFrom<RigDocument> for aws_bedrock::DocumentBlock {
|
||||
type Error = CompletionError;
|
||||
|
||||
fn try_from(value: RigDocument) -> Result<Self, Self::Error> {
|
||||
let maybe_format = value
|
||||
.0
|
||||
.media_type
|
||||
.map(|doc| RigDocumentMediaType(doc).try_into());
|
||||
|
||||
let format = match maybe_format {
|
||||
Some(Ok(document_format)) => Ok(Some(document_format)),
|
||||
Some(Err(err)) => Err(err),
|
||||
None => Ok(None),
|
||||
}?;
|
||||
|
||||
let document_data = BASE64_STANDARD
|
||||
.decode(value.0.data)
|
||||
.map_err(|e| CompletionError::ProviderError(e.to_string()))?;
|
||||
let data = aws_smithy_types::Blob::new(document_data);
|
||||
let document_source = aws_bedrock::DocumentSource::Bytes(data);
|
||||
|
||||
let result = aws_bedrock::DocumentBlock::builder()
|
||||
.source(document_source)
|
||||
.name("Document")
|
||||
.set_format(format)
|
||||
.build()
|
||||
.map_err(|e| CompletionError::ProviderError(e.to_string()))?;
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<aws_bedrock::DocumentBlock> for RigDocument {
|
||||
type Error = CompletionError;
|
||||
|
||||
fn try_from(value: aws_bedrock::DocumentBlock) -> Result<Self, Self::Error> {
|
||||
let media_type: RigDocumentMediaType = value.format.try_into()?;
|
||||
let media_type = media_type.0;
|
||||
|
||||
let data = match value.source {
|
||||
Some(aws_bedrock::DocumentSource::Bytes(blob)) => {
|
||||
let encoded_data = BASE64_STANDARD.encode(blob.into_inner());
|
||||
Ok(encoded_data)
|
||||
}
|
||||
_ => Err(CompletionError::ProviderError(
|
||||
"Document source is missing".into(),
|
||||
)),
|
||||
}?;
|
||||
|
||||
Ok(RigDocument(Document {
|
||||
data,
|
||||
format: Some(ContentFormat::Base64),
|
||||
media_type: Some(media_type),
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use aws_sdk_bedrockruntime::types as aws_bedrock;
|
||||
use base64::{prelude::BASE64_STANDARD, Engine};
|
||||
use rig::{
|
||||
completion::CompletionError,
|
||||
message::{ContentFormat, Document, DocumentMediaType},
|
||||
};
|
||||
|
||||
use crate::types::document::RigDocument;
|
||||
|
||||
#[test]
|
||||
fn test_document_to_aws_document() {
|
||||
let rig_document = RigDocument(Document {
|
||||
data: "data".into(),
|
||||
format: Some(ContentFormat::Base64),
|
||||
media_type: Some(DocumentMediaType::PDF),
|
||||
});
|
||||
let aws_document: Result<aws_bedrock::DocumentBlock, _> = rig_document.clone().try_into();
|
||||
assert_eq!(aws_document.is_ok(), true);
|
||||
let aws_document = aws_document.unwrap();
|
||||
assert_eq!(aws_document.format, aws_bedrock::DocumentFormat::Pdf);
|
||||
let document_data = BASE64_STANDARD.decode(rig_document.0.data).unwrap();
|
||||
let aws_document_bytes = aws_document
|
||||
.source()
|
||||
.unwrap()
|
||||
.as_bytes()
|
||||
.unwrap()
|
||||
.as_ref()
|
||||
.to_owned();
|
||||
assert_eq!(aws_document_bytes, document_data)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unsupported_document_to_aws_document() {
|
||||
let rig_document = RigDocument(Document {
|
||||
data: "data".into(),
|
||||
format: Some(ContentFormat::Base64),
|
||||
media_type: Some(DocumentMediaType::Javascript),
|
||||
});
|
||||
let aws_document: Result<aws_bedrock::DocumentBlock, _> = rig_document.clone().try_into();
|
||||
assert_eq!(
|
||||
aws_document.err().unwrap().to_string(),
|
||||
CompletionError::ProviderError(
|
||||
"Unsupported media type application/x-javascript".into()
|
||||
)
|
||||
.to_string()
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aws_document_to_rig_document() {
|
||||
let data = aws_smithy_types::Blob::new("document_data");
|
||||
let document_source = aws_bedrock::DocumentSource::Bytes(data);
|
||||
let aws_document = aws_bedrock::DocumentBlock::builder()
|
||||
.format(aws_bedrock::DocumentFormat::Pdf)
|
||||
.name("Document")
|
||||
.source(document_source)
|
||||
.build()
|
||||
.unwrap();
|
||||
let rig_document: Result<RigDocument, _> = aws_document.clone().try_into();
|
||||
assert_eq!(rig_document.is_ok(), true);
|
||||
let rig_document = rig_document.unwrap().0;
|
||||
assert_eq!(rig_document.media_type.unwrap(), DocumentMediaType::PDF)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unsupported_aws_document_to_rig_document() {
|
||||
let data = aws_smithy_types::Blob::new("document_data");
|
||||
let document_source = aws_bedrock::DocumentSource::Bytes(data);
|
||||
let aws_document = aws_bedrock::DocumentBlock::builder()
|
||||
.format(aws_bedrock::DocumentFormat::Xlsx)
|
||||
.name("Document")
|
||||
.source(document_source)
|
||||
.build()
|
||||
.unwrap();
|
||||
let rig_document: Result<RigDocument, _> = aws_document.clone().try_into();
|
||||
assert_eq!(rig_document.is_ok(), false);
|
||||
assert_eq!(
|
||||
rig_document.err().unwrap().to_string(),
|
||||
CompletionError::ProviderError("Unsupported media type xlsx".into()).to_string()
|
||||
)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,81 @@
|
|||
use aws_sdk_bedrockruntime::config::http::HttpResponse;
|
||||
use aws_sdk_bedrockruntime::error::SdkError;
|
||||
use aws_sdk_bedrockruntime::operation::converse::ConverseError;
|
||||
use aws_sdk_bedrockruntime::operation::converse_stream::ConverseStreamError;
|
||||
use aws_sdk_bedrockruntime::operation::invoke_model::InvokeModelError;
|
||||
use rig::completion::CompletionError;
|
||||
use rig::embeddings::EmbeddingError;
|
||||
use rig::image_generation::ImageGenerationError;
|
||||
|
||||
pub struct AwsSdkInvokeModelError(pub SdkError<InvokeModelError, HttpResponse>);
|
||||
|
||||
impl AwsSdkInvokeModelError {
|
||||
pub fn into_service_error(self) -> String {
|
||||
let error: String = match self.0.into_service_error() {
|
||||
InvokeModelError::ModelTimeoutException(e) => e.message.unwrap_or("The request took too long to process. Processing time exceeded the model timeout length.".into()),
|
||||
InvokeModelError::AccessDeniedException(e) => e.message.unwrap_or("The request is denied because you do not have sufficient permissions to perform the requested action.".into()),
|
||||
InvokeModelError::ResourceNotFoundException(e) => e.message.unwrap_or("The specified resource ARN was not found.".into()),
|
||||
InvokeModelError::ThrottlingException(e) => e.message.unwrap_or("Your request was denied due to exceeding the account quotas for Amazon Bedrock.".into()),
|
||||
InvokeModelError::ServiceUnavailableException(e) => e.message.unwrap_or("The service isn't currently available.".into()),
|
||||
InvokeModelError::InternalServerException(e) => e.message.unwrap_or("An internal server error occurred.".into()),
|
||||
InvokeModelError::ValidationException(e) => e.message.unwrap_or("The input fails to satisfy the constraints specified by Amazon Bedrock.".into()),
|
||||
InvokeModelError::ModelNotReadyException(e) => e.message.unwrap_or("The model specified in the request is not ready to serve inference requests. The AWS SDK will automatically retry the operation up to 5 times.".into()),
|
||||
InvokeModelError::ModelErrorException(e) => e.message.unwrap_or("The request failed due to an error while processing the model.".into()),
|
||||
InvokeModelError::ServiceQuotaExceededException(e) => e.message.unwrap_or("Your request exceeds the service quota for your account.".into()),
|
||||
_ => "An unexpected error occurred. Verify Internet connection or AWS keys".into(),
|
||||
};
|
||||
error
|
||||
}
|
||||
}
|
||||
|
||||
impl From<AwsSdkInvokeModelError> for ImageGenerationError {
|
||||
fn from(value: AwsSdkInvokeModelError) -> Self {
|
||||
ImageGenerationError::ProviderError(value.into_service_error())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<AwsSdkInvokeModelError> for EmbeddingError {
|
||||
fn from(value: AwsSdkInvokeModelError) -> Self {
|
||||
EmbeddingError::ProviderError(value.into_service_error())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AwsSdkConverseError(pub SdkError<ConverseError, HttpResponse>);
|
||||
|
||||
impl From<AwsSdkConverseError> for CompletionError {
|
||||
fn from(value: AwsSdkConverseError) -> Self {
|
||||
let error: String = match value.0.into_service_error() {
|
||||
ConverseError::ModelTimeoutException(e) => e.message.unwrap_or("The request took too long to process. Processing time exceeded the model timeout length.".into()),
|
||||
ConverseError::AccessDeniedException(e) => e.message.unwrap_or("The request is denied because you do not have sufficient permissions to perform the requested action.".into()),
|
||||
ConverseError::ResourceNotFoundException(e) => e.message.unwrap_or("The specified resource ARN was not found.".into()),
|
||||
ConverseError::ThrottlingException(e) => e.message.unwrap_or("Your request was denied due to exceeding the account quotas for AWS Bedrock.".into()),
|
||||
ConverseError::ServiceUnavailableException(e) => e.message.unwrap_or("The service isn't currently available.".into()),
|
||||
ConverseError::InternalServerException(e) => e.message.unwrap_or("An internal server error occurred.".into()),
|
||||
ConverseError::ValidationException(e) => e.message.unwrap_or("The input fails to satisfy the constraints specified by AWS Bedrock.".into()),
|
||||
ConverseError::ModelNotReadyException(e) => e.message.unwrap_or("The model specified in the request is not ready to serve inference requests. The AWS SDK will automatically retry the operation up to 5 times.".into()),
|
||||
ConverseError::ModelErrorException(e) => e.message.unwrap_or("The request failed due to an error while processing the model.".into()),
|
||||
_ => String::from("An unexpected error occurred. Verify Internet connection or AWS keys")
|
||||
};
|
||||
CompletionError::ProviderError(error)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AwsSdkConverseStreamError(pub SdkError<ConverseStreamError, HttpResponse>);
|
||||
impl From<AwsSdkConverseStreamError> for CompletionError {
|
||||
fn from(value: AwsSdkConverseStreamError) -> Self {
|
||||
let error: String = match value.0.into_service_error() {
|
||||
ConverseStreamError::ModelTimeoutException(e) => e.message.unwrap(),
|
||||
ConverseStreamError::AccessDeniedException(e) => e.message.unwrap(),
|
||||
ConverseStreamError::ResourceNotFoundException(e) => e.message.unwrap(),
|
||||
ConverseStreamError::ThrottlingException(e) => e.message.unwrap(),
|
||||
ConverseStreamError::ServiceUnavailableException(e) => e.message.unwrap(),
|
||||
ConverseStreamError::InternalServerException(e) => e.message.unwrap(),
|
||||
ConverseStreamError::ModelStreamErrorException(e) => e.message.unwrap(),
|
||||
ConverseStreamError::ValidationException(e) => e.message.unwrap(),
|
||||
ConverseStreamError::ModelNotReadyException(e) => e.message.unwrap(),
|
||||
ConverseStreamError::ModelErrorException(e) => e.message.unwrap(),
|
||||
_ => "An unexpected error occurred. Verify Internet connection or AWS keys".into(),
|
||||
};
|
||||
CompletionError::ProviderError(error)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,129 @@
|
|||
use aws_sdk_bedrockruntime::types as aws_bedrock;
|
||||
|
||||
use rig::{
|
||||
completion::CompletionError,
|
||||
message::{ContentFormat, Image, ImageMediaType, MimeType},
|
||||
};
|
||||
|
||||
use base64::{prelude::BASE64_STANDARD, Engine};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RigImage(pub Image);
|
||||
|
||||
impl TryFrom<RigImage> for aws_bedrock::ImageBlock {
|
||||
type Error = CompletionError;
|
||||
|
||||
fn try_from(image: RigImage) -> Result<Self, Self::Error> {
|
||||
let maybe_format: Option<Result<aws_bedrock::ImageFormat, CompletionError>> =
|
||||
image.0.media_type.map(|f| match f {
|
||||
ImageMediaType::JPEG => Ok(aws_bedrock::ImageFormat::Jpeg),
|
||||
ImageMediaType::PNG => Ok(aws_bedrock::ImageFormat::Png),
|
||||
ImageMediaType::GIF => Ok(aws_bedrock::ImageFormat::Gif),
|
||||
ImageMediaType::WEBP => Ok(aws_bedrock::ImageFormat::Webp),
|
||||
e => Err(CompletionError::ProviderError(format!(
|
||||
"Unsupported format {}",
|
||||
e.to_mime_type()
|
||||
))),
|
||||
});
|
||||
|
||||
let format = match maybe_format {
|
||||
Some(Ok(image_format)) => Ok(Some(image_format)),
|
||||
Some(Err(err)) => Err(err),
|
||||
None => Ok(None),
|
||||
}?;
|
||||
|
||||
let img_data = BASE64_STANDARD
|
||||
.decode(image.0.data)
|
||||
.map_err(|e| CompletionError::ProviderError(e.to_string()))?;
|
||||
let blob = aws_smithy_types::Blob::new(img_data);
|
||||
let result = aws_bedrock::ImageBlock::builder()
|
||||
.set_format(format)
|
||||
.source(aws_bedrock::ImageSource::Bytes(blob))
|
||||
.build()
|
||||
.map_err(|e| CompletionError::ProviderError(e.to_string()))?;
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<aws_bedrock::ImageBlock> for RigImage {
|
||||
type Error = CompletionError;
|
||||
|
||||
fn try_from(image: aws_bedrock::ImageBlock) -> Result<Self, Self::Error> {
|
||||
let media_type = match image.format {
|
||||
aws_bedrock::ImageFormat::Gif => Ok(ImageMediaType::GIF),
|
||||
aws_bedrock::ImageFormat::Jpeg => Ok(ImageMediaType::JPEG),
|
||||
aws_bedrock::ImageFormat::Png => Ok(ImageMediaType::PNG),
|
||||
aws_bedrock::ImageFormat::Webp => Ok(ImageMediaType::WEBP),
|
||||
e => Err(CompletionError::ProviderError(format!(
|
||||
"Unsupported format {}",
|
||||
e
|
||||
))),
|
||||
}?;
|
||||
|
||||
let data = match image.source {
|
||||
Some(aws_bedrock::ImageSource::Bytes(blob)) => {
|
||||
let encoded_img = BASE64_STANDARD.encode(blob.into_inner());
|
||||
Ok(encoded_img)
|
||||
}
|
||||
_ => Err(CompletionError::ProviderError(
|
||||
"Image source is missing".into(),
|
||||
)),
|
||||
}?;
|
||||
Ok(RigImage(Image {
|
||||
data,
|
||||
format: Some(ContentFormat::Base64),
|
||||
media_type: Some(media_type),
|
||||
detail: None,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use aws_sdk_bedrockruntime::types as aws_bedrock;
|
||||
use base64::{prelude::BASE64_STANDARD, Engine};
|
||||
use rig::{
|
||||
completion::CompletionError,
|
||||
message::{ContentFormat, Image, ImageMediaType},
|
||||
};
|
||||
|
||||
use crate::types::image::RigImage;
|
||||
|
||||
#[test]
|
||||
fn test_image_to_aws_image() {
|
||||
let rig_image = RigImage(Image {
|
||||
data: BASE64_STANDARD.encode("img_data"),
|
||||
format: Some(ContentFormat::Base64),
|
||||
media_type: Some(ImageMediaType::JPEG),
|
||||
detail: None,
|
||||
});
|
||||
let aws_image: Result<aws_bedrock::ImageBlock, _> = rig_image.clone().try_into();
|
||||
assert_eq!(aws_image.is_ok(), true);
|
||||
let aws_image = aws_image.unwrap();
|
||||
assert_eq!(aws_image.format, aws_bedrock::ImageFormat::Jpeg);
|
||||
let img_data = BASE64_STANDARD.decode(rig_image.0.data).unwrap();
|
||||
let aws_image_bytes = aws_image
|
||||
.source()
|
||||
.unwrap()
|
||||
.as_bytes()
|
||||
.unwrap()
|
||||
.as_ref()
|
||||
.to_owned();
|
||||
assert_eq!(aws_image_bytes, img_data)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unsupported_image_to_aws_image() {
|
||||
let rig_image = RigImage(Image {
|
||||
data: BASE64_STANDARD.encode("img_data"),
|
||||
format: Some(ContentFormat::Base64),
|
||||
media_type: Some(ImageMediaType::HEIC),
|
||||
detail: None,
|
||||
});
|
||||
let aws_image: Result<aws_bedrock::ImageBlock, _> = rig_image.clone().try_into();
|
||||
assert_eq!(
|
||||
aws_image.err().unwrap().to_string(),
|
||||
CompletionError::ProviderError("Unsupported format image/heic".into()).to_string()
|
||||
)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,182 @@
|
|||
use aws_smithy_types::{Document, Number};
|
||||
use serde_json::{Map, Value};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AwsDocument(pub Document);
|
||||
|
||||
impl From<AwsDocument> for Value {
|
||||
fn from(value: AwsDocument) -> Self {
|
||||
match value.0 {
|
||||
Document::Object(obj) => {
|
||||
let documents = obj
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, AwsDocument(v).into()))
|
||||
.collect::<Map<_, _>>();
|
||||
Value::Object(documents)
|
||||
}
|
||||
Document::Array(arr) => {
|
||||
let documents = arr.into_iter().map(|v| AwsDocument(v).into()).collect();
|
||||
Value::Array(documents)
|
||||
}
|
||||
Document::Number(Number::PosInt(number)) => {
|
||||
Value::Number(serde_json::Number::from(number))
|
||||
}
|
||||
Document::Number(Number::NegInt(number)) => {
|
||||
Value::Number(serde_json::Number::from(number))
|
||||
}
|
||||
Document::Number(Number::Float(number)) => match serde_json::Number::from_f64(number) {
|
||||
Some(n) => Value::Number(n),
|
||||
// https://www.rfc-editor.org/rfc/rfc7159
|
||||
// Numeric values that cannot be represented in the grammar (such as Infinity and NaN) are not permitted.
|
||||
None => Value::Null,
|
||||
},
|
||||
Document::String(s) => Value::String(s),
|
||||
Document::Bool(b) => Value::Bool(b),
|
||||
Document::Null => Value::Null,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Value> for AwsDocument {
|
||||
fn from(value: Value) -> Self {
|
||||
match value {
|
||||
Value::Null => AwsDocument(Document::Null),
|
||||
Value::Bool(b) => AwsDocument(Document::Bool(b)),
|
||||
Value::Number(num) => {
|
||||
if let Some(i) = num.as_i64() {
|
||||
match i > 0 {
|
||||
true => AwsDocument(Document::Number(Number::PosInt(i as u64))),
|
||||
false => AwsDocument(Document::Number(Number::NegInt(i))),
|
||||
}
|
||||
} else if let Some(f) = num.as_f64() {
|
||||
AwsDocument(Document::Number(Number::Float(f)))
|
||||
} else {
|
||||
AwsDocument(Document::Null)
|
||||
}
|
||||
}
|
||||
Value::String(s) => AwsDocument(Document::String(s)),
|
||||
Value::Array(arr) => {
|
||||
let documents = arr
|
||||
.into_iter()
|
||||
.map(|json| json.into())
|
||||
.map(|aws: AwsDocument| aws.0)
|
||||
.collect();
|
||||
AwsDocument(Document::Array(documents))
|
||||
}
|
||||
Value::Object(obj) => {
|
||||
let documents = obj
|
||||
.into_iter()
|
||||
.map(|(k, v)| {
|
||||
let doc: AwsDocument = v.into();
|
||||
(k, doc.0)
|
||||
})
|
||||
.collect::<HashMap<_, _>>();
|
||||
AwsDocument(Document::Object(documents))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashMap;
|
||||
|
||||
use aws_smithy_types::{Document, Number};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::types::json::AwsDocument;
|
||||
|
||||
#[test]
|
||||
fn test_json_to_aws_document() {
|
||||
let json = r#"
|
||||
{
|
||||
"type": "object",
|
||||
"is_enabled": true,
|
||||
"version": 42,
|
||||
"fraction": 1.23,
|
||||
"negative": -11,
|
||||
"properties": {
|
||||
"x": {
|
||||
"type": "number",
|
||||
"description": "The first number to add"
|
||||
},
|
||||
"y": {
|
||||
"type": "number",
|
||||
"description": "The second number to add"
|
||||
}
|
||||
},
|
||||
"required":["x", "y", null]
|
||||
}
|
||||
"#;
|
||||
|
||||
let value: Value = serde_json::from_str(json).unwrap();
|
||||
let document: AwsDocument = value.into();
|
||||
println!("{:?}", document);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aws_document_to_json() {
|
||||
let document = AwsDocument(Document::Object(HashMap::from([
|
||||
(
|
||||
String::from("type"),
|
||||
Document::String(String::from("object")),
|
||||
),
|
||||
(
|
||||
String::from("version"),
|
||||
Document::Number(Number::PosInt(42)),
|
||||
),
|
||||
(
|
||||
String::from("fraction"),
|
||||
Document::Number(Number::Float(1.23)),
|
||||
),
|
||||
(
|
||||
String::from("negative"),
|
||||
Document::Number(Number::NegInt(-11)),
|
||||
),
|
||||
(String::from("is_enabled"), Document::Bool(true)),
|
||||
(
|
||||
String::from("properties"),
|
||||
Document::Object(HashMap::from([
|
||||
(
|
||||
String::from("x"),
|
||||
Document::Object(HashMap::from([
|
||||
(
|
||||
String::from("type"),
|
||||
Document::String(String::from("number")),
|
||||
),
|
||||
(
|
||||
String::from("description"),
|
||||
Document::String(String::from("The first number to add")),
|
||||
),
|
||||
])),
|
||||
),
|
||||
(
|
||||
String::from("y"),
|
||||
Document::Object(HashMap::from([
|
||||
(
|
||||
String::from("type"),
|
||||
Document::String(String::from("number")),
|
||||
),
|
||||
(
|
||||
String::from("description"),
|
||||
Document::String(String::from("The second number to add")),
|
||||
),
|
||||
])),
|
||||
),
|
||||
])),
|
||||
),
|
||||
(
|
||||
String::from("required"),
|
||||
Document::Array(vec![
|
||||
Document::String(String::from("x")),
|
||||
Document::String(String::from("y")),
|
||||
Document::Null,
|
||||
]),
|
||||
),
|
||||
])));
|
||||
|
||||
let json: Value = document.into();
|
||||
println!("{:?}", json);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
use aws_sdk_bedrockruntime::types::DocumentFormat;
|
||||
use rig::{
|
||||
completion::CompletionError,
|
||||
message::{DocumentMediaType, MimeType},
|
||||
};
|
||||
|
||||
pub struct RigDocumentMediaType(pub DocumentMediaType);
|
||||
|
||||
impl TryFrom<RigDocumentMediaType> for DocumentFormat {
|
||||
type Error = CompletionError;
|
||||
|
||||
fn try_from(value: RigDocumentMediaType) -> Result<Self, Self::Error> {
|
||||
match value.0 {
|
||||
DocumentMediaType::PDF => Ok(DocumentFormat::Pdf),
|
||||
DocumentMediaType::TXT => Ok(DocumentFormat::Txt),
|
||||
DocumentMediaType::HTML => Ok(DocumentFormat::Html),
|
||||
DocumentMediaType::MARKDOWN => Ok(DocumentFormat::Md),
|
||||
DocumentMediaType::CSV => Ok(DocumentFormat::Csv),
|
||||
e => Err(CompletionError::ProviderError(format!(
|
||||
"Unsupported media type {}",
|
||||
e.to_mime_type()
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<DocumentFormat> for RigDocumentMediaType {
|
||||
type Error = CompletionError;
|
||||
|
||||
fn try_from(value: DocumentFormat) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
DocumentFormat::Csv => Ok(RigDocumentMediaType(DocumentMediaType::CSV)),
|
||||
DocumentFormat::Html => Ok(RigDocumentMediaType(DocumentMediaType::HTML)),
|
||||
DocumentFormat::Md => Ok(RigDocumentMediaType(DocumentMediaType::MARKDOWN)),
|
||||
DocumentFormat::Pdf => Ok(RigDocumentMediaType(DocumentMediaType::PDF)),
|
||||
DocumentFormat::Txt => Ok(RigDocumentMediaType(DocumentMediaType::TXT)),
|
||||
e => Err(CompletionError::ProviderError(format!(
|
||||
"Unsupported media type {}",
|
||||
e
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,111 @@
|
|||
use aws_sdk_bedrockruntime::types as aws_bedrock;
|
||||
|
||||
use rig::{
|
||||
completion::CompletionError,
|
||||
message::{AssistantContent, Message, UserContent},
|
||||
OneOrMany,
|
||||
};
|
||||
|
||||
use super::{assistant_content::RigAssistantContent, user_content::RigUserContent};
|
||||
|
||||
pub struct RigMessage(pub Message);
|
||||
|
||||
impl TryFrom<RigMessage> for aws_bedrock::Message {
|
||||
type Error = CompletionError;
|
||||
|
||||
fn try_from(value: RigMessage) -> Result<Self, Self::Error> {
|
||||
let result = match value.0 {
|
||||
Message::User { content } => {
|
||||
let message_content = content
|
||||
.into_iter()
|
||||
.map(|user_content| RigUserContent(user_content).try_into())
|
||||
.collect::<Result<Vec<Vec<_>>, _>>()
|
||||
.map_err(|e| CompletionError::RequestError(Box::new(e)))
|
||||
.map(|nested| nested.into_iter().flatten().collect())?;
|
||||
|
||||
aws_bedrock::Message::builder()
|
||||
.role(aws_bedrock::ConversationRole::User)
|
||||
.set_content(Some(message_content))
|
||||
.build()
|
||||
.map_err(|e| CompletionError::RequestError(Box::new(e)))?
|
||||
}
|
||||
Message::Assistant { content } => aws_bedrock::Message::builder()
|
||||
.role(aws_bedrock::ConversationRole::Assistant)
|
||||
.set_content(Some(
|
||||
content
|
||||
.into_iter()
|
||||
.map(|content| RigAssistantContent(content).try_into())
|
||||
.collect::<Result<Vec<aws_bedrock::ContentBlock>, _>>()?,
|
||||
))
|
||||
.build()
|
||||
.map_err(|e| CompletionError::RequestError(Box::new(e)))?,
|
||||
};
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<aws_bedrock::Message> for RigMessage {
|
||||
type Error = CompletionError;
|
||||
|
||||
fn try_from(message: aws_bedrock::Message) -> Result<Self, Self::Error> {
|
||||
match message.role {
|
||||
aws_bedrock::ConversationRole::Assistant => {
|
||||
let assistant_content = message
|
||||
.content
|
||||
.into_iter()
|
||||
.map(|c| c.try_into())
|
||||
.collect::<Result<Vec<RigAssistantContent>, _>>()?
|
||||
.into_iter()
|
||||
.map(|rig_assistant_content| rig_assistant_content.0)
|
||||
.collect::<Vec<AssistantContent>>();
|
||||
|
||||
let content = OneOrMany::many(assistant_content)
|
||||
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;
|
||||
|
||||
Ok(RigMessage(Message::Assistant { content }))
|
||||
}
|
||||
aws_bedrock::ConversationRole::User => {
|
||||
let user_content = message
|
||||
.content
|
||||
.into_iter()
|
||||
.map(|c| c.try_into())
|
||||
.collect::<Result<Vec<RigUserContent>, _>>()?
|
||||
.into_iter()
|
||||
.map(|user_content| user_content.0)
|
||||
.collect::<Vec<UserContent>>();
|
||||
|
||||
let content = OneOrMany::many(user_content)
|
||||
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;
|
||||
Ok(RigMessage(Message::User { content }))
|
||||
}
|
||||
_ => Err(CompletionError::ProviderError(
|
||||
"AWS Bedrock returned unsupported ConversationRole".into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::types::message::RigMessage;
|
||||
use aws_sdk_bedrockruntime::types as aws_bedrock;
|
||||
use rig::{
|
||||
message::{Message, UserContent},
|
||||
OneOrMany,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn message_to_aws_message() {
|
||||
let message = Message::User {
|
||||
content: OneOrMany::one(UserContent::Text("text".into())),
|
||||
};
|
||||
let aws_message: Result<aws_bedrock::Message, _> = RigMessage(message).try_into();
|
||||
assert_eq!(aws_message.is_ok(), true);
|
||||
let aws_message = aws_message.unwrap();
|
||||
assert_eq!(aws_message.role, aws_bedrock::ConversationRole::User);
|
||||
assert_eq!(
|
||||
aws_message.content,
|
||||
vec![aws_bedrock::ContentBlock::Text("text".into())]
|
||||
);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
pub(crate) mod assistant_content;
|
||||
pub(crate) mod completion_request;
|
||||
pub(crate) mod document;
|
||||
pub(crate) mod errors;
|
||||
pub(crate) mod image;
|
||||
pub(crate) mod json;
|
||||
pub(crate) mod media_types;
|
||||
pub(crate) mod message;
|
||||
pub(crate) mod text_to_image;
|
||||
pub(crate) mod tool;
|
||||
pub(crate) mod user_content;
|
|
@ -0,0 +1,129 @@
|
|||
use base64::prelude::BASE64_STANDARD;
|
||||
use base64::Engine;
|
||||
use rig::image_generation;
|
||||
use rig::image_generation::ImageGenerationError;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ImageQuality {
|
||||
Standard,
|
||||
Premium,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ImageGenerationConfig {
|
||||
// The quality of the image.
|
||||
// Default: standard
|
||||
pub quality: Option<ImageQuality>,
|
||||
// The number of images to generate.
|
||||
// Default: 1, Minimum: 1, Maximum: 5
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub number_of_images: Option<u32>,
|
||||
// The height of the image in pixels.
|
||||
pub height: Option<u32>,
|
||||
// The width of the image in pixels.
|
||||
pub width: Option<u32>,
|
||||
// Specifies how strongly the generated image should adhere to the prompt. Use a lower value to introduce more randomness in the generation.
|
||||
// Default: 8.0. Minimum: 1.1, Maximum: 10.0
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cfg_scale: Option<f32>,
|
||||
// Use to control and reproduce results. Determines the initial noise setting.
|
||||
// Use the same seed and the same settings as a previous run to allow inference to create a similar image.
|
||||
// Default: 42, Minimum: 0, Maximum: 2147483646
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub seed: Option<u32>,
|
||||
}
|
||||
|
||||
impl Default for ImageGenerationConfig {
|
||||
fn default() -> Self {
|
||||
ImageGenerationConfig {
|
||||
quality: Some(ImageQuality::Standard),
|
||||
number_of_images: Some(1),
|
||||
height: Some(512),
|
||||
width: Some(512),
|
||||
cfg_scale: None,
|
||||
seed: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct TextToImageParams {
|
||||
pub text: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub negative_text: Option<String>,
|
||||
}
|
||||
|
||||
impl TextToImageParams {
|
||||
pub fn new(text: String) -> Self {
|
||||
Self {
|
||||
text,
|
||||
negative_text: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct TextToImageGeneration {
|
||||
pub task_type: &'static str,
|
||||
pub text_to_image_params: TextToImageParams,
|
||||
pub image_generation_config: ImageGenerationConfig,
|
||||
}
|
||||
|
||||
impl TextToImageGeneration {
|
||||
pub(crate) fn new(text: String) -> TextToImageGeneration {
|
||||
TextToImageGeneration {
|
||||
task_type: "TEXT_IMAGE",
|
||||
text_to_image_params: TextToImageParams::new(text),
|
||||
image_generation_config: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn height(&mut self, height: u32) -> &Self {
|
||||
self.image_generation_config.height = Some(height);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn width(&mut self, width: u32) -> &Self {
|
||||
self.image_generation_config.width = Some(width);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Debug)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct TextToImageResponse {
|
||||
pub images: Option<Vec<String>>,
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
impl TryFrom<TextToImageResponse>
|
||||
for image_generation::ImageGenerationResponse<TextToImageResponse>
|
||||
{
|
||||
type Error = ImageGenerationError;
|
||||
|
||||
fn try_from(value: TextToImageResponse) -> Result<Self, Self::Error> {
|
||||
if let Some(error) = value.error {
|
||||
return Err(ImageGenerationError::ResponseError(error));
|
||||
}
|
||||
|
||||
if let Some(images) = value.to_owned().images {
|
||||
let data = BASE64_STANDARD
|
||||
.decode(&images[0])
|
||||
.expect("Could not decode image.");
|
||||
|
||||
return Ok(Self {
|
||||
image: data,
|
||||
response: value,
|
||||
});
|
||||
}
|
||||
|
||||
Err(ImageGenerationError::ResponseError(
|
||||
"Malformed response from model".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,124 @@
|
|||
use aws_sdk_bedrockruntime::types as aws_bedrock;
|
||||
|
||||
use rig::{
|
||||
completion::CompletionError,
|
||||
message::{Text, ToolResultContent},
|
||||
};
|
||||
use serde_json::Value;
|
||||
|
||||
use super::{image::RigImage, json::AwsDocument};
|
||||
|
||||
pub struct RigToolResultContent(pub ToolResultContent);
|
||||
|
||||
impl TryFrom<RigToolResultContent> for aws_bedrock::ToolResultContentBlock {
|
||||
type Error = CompletionError;
|
||||
|
||||
fn try_from(value: RigToolResultContent) -> Result<Self, Self::Error> {
|
||||
match value.0 {
|
||||
ToolResultContent::Text(text) => {
|
||||
Ok(aws_bedrock::ToolResultContentBlock::Text(text.text))
|
||||
}
|
||||
ToolResultContent::Image(image) => {
|
||||
let image = RigImage(image).try_into()?;
|
||||
Ok(aws_bedrock::ToolResultContentBlock::Image(image))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<aws_bedrock::ToolResultContentBlock> for RigToolResultContent {
|
||||
type Error = CompletionError;
|
||||
|
||||
fn try_from(value: aws_bedrock::ToolResultContentBlock) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
aws_bedrock::ToolResultContentBlock::Image(image) => {
|
||||
let image: RigImage = image.try_into()?;
|
||||
Ok(RigToolResultContent(ToolResultContent::Image(image.0)))
|
||||
}
|
||||
aws_bedrock::ToolResultContentBlock::Json(document) => {
|
||||
let json: Value = AwsDocument(document).into();
|
||||
Ok(RigToolResultContent(ToolResultContent::Text(Text {
|
||||
text: json.to_string(),
|
||||
})))
|
||||
}
|
||||
aws_bedrock::ToolResultContentBlock::Text(text) => {
|
||||
Ok(RigToolResultContent(ToolResultContent::Text(Text { text })))
|
||||
}
|
||||
_ => Err(CompletionError::ProviderError(
|
||||
"ToolResultContentBlock contains unsupported variant".into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use aws_sdk_bedrockruntime::types as aws_bedrock;
|
||||
use base64::{prelude::BASE64_STANDARD, Engine};
|
||||
use rig::{
|
||||
completion::CompletionError,
|
||||
message::{ContentFormat, Image, ImageMediaType, Text, ToolResultContent},
|
||||
};
|
||||
|
||||
use crate::types::tool::RigToolResultContent;
|
||||
|
||||
#[test]
|
||||
fn rig_tool_text_to_aws_tool() {
|
||||
let tool = RigToolResultContent(ToolResultContent::Text(Text { text: "42".into() }));
|
||||
let aws_tool: Result<aws_bedrock::ToolResultContentBlock, _> = tool.try_into();
|
||||
assert_eq!(aws_tool.is_ok(), true);
|
||||
assert_eq!(
|
||||
String::from(aws_tool.unwrap().as_text().unwrap()),
|
||||
String::from("42")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rig_tool_image_to_aws_tool() {
|
||||
let image = Image {
|
||||
data: BASE64_STANDARD.encode("img_data"),
|
||||
format: Some(ContentFormat::Base64),
|
||||
media_type: Some(ImageMediaType::JPEG),
|
||||
detail: None,
|
||||
};
|
||||
let tool = RigToolResultContent(ToolResultContent::Image(image));
|
||||
let aws_tool: Result<aws_bedrock::ToolResultContentBlock, _> = tool.try_into();
|
||||
assert_eq!(aws_tool.is_ok(), true);
|
||||
assert_eq!(aws_tool.unwrap().is_image(), true)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn aws_tool_to_rig_tool() {
|
||||
let aws_tool = aws_bedrock::ToolResultContentBlock::Text("txt".into());
|
||||
let tool: Result<RigToolResultContent, _> = aws_tool.try_into();
|
||||
assert_eq!(tool.is_ok(), true);
|
||||
let tool = match tool.unwrap().0 {
|
||||
ToolResultContent::Text(text) => Ok(text),
|
||||
_ => Err("tool doesn't contain text"),
|
||||
};
|
||||
assert_eq!(tool.is_ok(), true);
|
||||
assert_eq!(tool.unwrap().text, String::from("txt"))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn aws_tool_to_unsupported_rig_tool() {
|
||||
let document_source =
|
||||
aws_bedrock::DocumentSource::Bytes(aws_smithy_types::Blob::new("document_data"));
|
||||
let aws_document = aws_bedrock::DocumentBlock::builder()
|
||||
.format(aws_bedrock::DocumentFormat::Pdf)
|
||||
.name("Document")
|
||||
.source(document_source)
|
||||
.build()
|
||||
.unwrap();
|
||||
let aws_tool = aws_bedrock::ToolResultContentBlock::Document(aws_document);
|
||||
let tool: Result<RigToolResultContent, _> = aws_tool.try_into();
|
||||
assert_eq!(tool.is_ok(), false);
|
||||
assert_eq!(
|
||||
tool.err().unwrap().to_string(),
|
||||
CompletionError::ProviderError(
|
||||
"ToolResultContentBlock contains unsupported variant".into()
|
||||
)
|
||||
.to_string()
|
||||
)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,173 @@
|
|||
use aws_sdk_bedrockruntime::types as aws_bedrock;
|
||||
|
||||
use rig::{
|
||||
completion::CompletionError,
|
||||
message::{Text, ToolResult, ToolResultContent, UserContent},
|
||||
OneOrMany,
|
||||
};
|
||||
|
||||
use super::{document::RigDocument, image::RigImage, tool::RigToolResultContent};
|
||||
|
||||
pub struct RigUserContent(pub UserContent);
|
||||
|
||||
impl TryFrom<aws_bedrock::ContentBlock> for RigUserContent {
|
||||
type Error = CompletionError;
|
||||
|
||||
fn try_from(value: aws_bedrock::ContentBlock) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
aws_bedrock::ContentBlock::Text(text) => {
|
||||
Ok(RigUserContent(UserContent::Text(Text { text })))
|
||||
}
|
||||
aws_bedrock::ContentBlock::ToolResult(tool_result) => {
|
||||
let tool_result_contents = tool_result
|
||||
.content
|
||||
.into_iter()
|
||||
.map(|tool| tool.try_into())
|
||||
.collect::<Result<Vec<RigToolResultContent>, _>>()?
|
||||
.into_iter()
|
||||
.map(|rt| rt.0)
|
||||
.collect::<Vec<ToolResultContent>>();
|
||||
|
||||
let tool_results = OneOrMany::many(tool_result_contents).map_err(|_| {
|
||||
CompletionError::ProviderError("ToolResult returned invalid response".into())
|
||||
})?;
|
||||
Ok(RigUserContent(UserContent::ToolResult(ToolResult {
|
||||
id: tool_result.tool_use_id,
|
||||
content: tool_results,
|
||||
})))
|
||||
}
|
||||
aws_bedrock::ContentBlock::Document(document) => {
|
||||
let doc: RigDocument = document.try_into()?;
|
||||
Ok(RigUserContent(UserContent::Document(doc.0)))
|
||||
}
|
||||
aws_bedrock::ContentBlock::Image(image) => {
|
||||
let image: RigImage = image.try_into()?;
|
||||
Ok(RigUserContent(UserContent::Image(image.0)))
|
||||
}
|
||||
_ => Err(CompletionError::ProviderError(
|
||||
"ToolResultContentBlock contains unsupported variant".into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<RigUserContent> for Vec<aws_bedrock::ContentBlock> {
|
||||
type Error = CompletionError;
|
||||
|
||||
fn try_from(value: RigUserContent) -> Result<Self, Self::Error> {
|
||||
match value.0 {
|
||||
UserContent::Text(text) => Ok(vec![aws_bedrock::ContentBlock::Text(text.text)]),
|
||||
UserContent::ToolResult(tool_result) => {
|
||||
let builder = aws_bedrock::ToolResultBlock::builder()
|
||||
.tool_use_id(tool_result.id)
|
||||
.set_content(Some(
|
||||
tool_result
|
||||
.content
|
||||
.into_iter()
|
||||
.map(|tool| RigToolResultContent(tool).try_into())
|
||||
.collect::<Result<Vec<aws_bedrock::ToolResultContentBlock>, _>>()?,
|
||||
))
|
||||
.build()
|
||||
.map_err(|e| CompletionError::ProviderError(e.to_string()))?;
|
||||
Ok(vec![aws_bedrock::ContentBlock::ToolResult(builder)])
|
||||
}
|
||||
UserContent::Image(image) => {
|
||||
let image = RigImage(image).try_into()?;
|
||||
Ok(vec![aws_bedrock::ContentBlock::Image(image)])
|
||||
}
|
||||
UserContent::Document(document) => {
|
||||
let doc = RigDocument(document).try_into()?;
|
||||
// AWS documentations: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-call.html
|
||||
// In the content field of the Message object, you must also include a text field with a prompt related to the document.
|
||||
Ok(vec![
|
||||
aws_bedrock::ContentBlock::Text("Use provided document".to_string()),
|
||||
aws_bedrock::ContentBlock::Document(doc),
|
||||
])
|
||||
}
|
||||
UserContent::Audio(_) => Err(CompletionError::ProviderError(
|
||||
"Audio is not supported".into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::types::user_content::RigUserContent;
|
||||
use aws_sdk_bedrockruntime::types as aws_bedrock;
|
||||
use rig::{
|
||||
completion::CompletionError,
|
||||
message::{ToolResultContent, UserContent},
|
||||
OneOrMany,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn aws_content_block_to_user_content() {
|
||||
let cb = aws_bedrock::ContentBlock::Text("42".into());
|
||||
let user_content: Result<RigUserContent, _> = cb.try_into();
|
||||
assert_eq!(user_content.is_ok(), true);
|
||||
let content = match user_content.unwrap().0 {
|
||||
rig::message::UserContent::Text(text) => Ok(text),
|
||||
_ => Err("Invalid content type"),
|
||||
};
|
||||
assert_eq!(content.is_ok(), true);
|
||||
assert_eq!(content.unwrap().text, "42")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn aws_content_block_tool_to_user_content() {
|
||||
let cb = aws_bedrock::ContentBlock::ToolResult(
|
||||
aws_bedrock::ToolResultBlock::builder()
|
||||
.tool_use_id("123")
|
||||
.content(aws_bedrock::ToolResultContentBlock::Text("content".into()))
|
||||
.build()
|
||||
.unwrap(),
|
||||
);
|
||||
let user_content: Result<RigUserContent, _> = cb.try_into();
|
||||
assert_eq!(user_content.is_ok(), true);
|
||||
let content = match user_content.unwrap().0 {
|
||||
rig::message::UserContent::ToolResult(tool_result) => Ok(tool_result),
|
||||
_ => Err("Invalid content type"),
|
||||
};
|
||||
assert_eq!(content.is_ok(), true);
|
||||
let content = content.unwrap();
|
||||
assert_eq!(content.id, "123");
|
||||
assert_eq!(
|
||||
content.content,
|
||||
OneOrMany::one(ToolResultContent::Text("content".into()))
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn aws_unsupported_content_block_to_user_content() {
|
||||
let cb = aws_bedrock::ContentBlock::GuardContent(
|
||||
aws_bedrock::GuardrailConverseContentBlock::Text(
|
||||
aws_bedrock::GuardrailConverseTextBlock::builder()
|
||||
.text("stuff")
|
||||
.build()
|
||||
.unwrap(),
|
||||
),
|
||||
);
|
||||
let user_content: Result<RigUserContent, _> = cb.try_into();
|
||||
assert_eq!(user_content.is_ok(), false);
|
||||
assert_eq!(
|
||||
user_content.err().unwrap().to_string(),
|
||||
CompletionError::ProviderError(
|
||||
"ToolResultContentBlock contains unsupported variant".into()
|
||||
)
|
||||
.to_string()
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn user_content_to_aws_content_block() {
|
||||
let uc = RigUserContent(UserContent::Text("txt".into()));
|
||||
let aws_content_blocks: Result<Vec<aws_bedrock::ContentBlock>, _> = uc.try_into();
|
||||
assert_eq!(aws_content_blocks.is_ok(), true);
|
||||
let aws_content_blocks = aws_content_blocks.unwrap();
|
||||
assert_eq!(
|
||||
aws_content_blocks,
|
||||
vec![aws_bedrock::ContentBlock::Text("txt".into())]
|
||||
);
|
||||
}
|
||||
}
|
|
@ -41,6 +41,7 @@ pub trait ImageGeneration<M: ImageGenerationModel> {
|
|||
> + Send;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ImageGenerationResponse<T> {
|
||||
pub image: Vec<u8>,
|
||||
pub response: T,
|
||||
|
|
|
@ -151,7 +151,7 @@ pub enum Message {
|
|||
)]
|
||||
tool_calls: Vec<ToolCall>,
|
||||
},
|
||||
#[serde(rename = "Tool")]
|
||||
#[serde(rename = "tool")]
|
||||
ToolResult {
|
||||
tool_call_id: String,
|
||||
content: String,
|
||||
|
|
|
@ -102,7 +102,7 @@ pub async fn stream_to_stdout<M: StreamingCompletionModel>(
|
|||
.tools
|
||||
.call(&name, params.to_string())
|
||||
.await
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
|
||||
.map_err(|e| std::io::Error::other(e.to_string()))?;
|
||||
println!("\nResult: {}", res);
|
||||
}
|
||||
Err(e) => {
|
||||
|
|
|
@ -277,7 +277,7 @@ impl Neo4jClient {
|
|||
/// ### Arguments
|
||||
/// * `index_name` - The name of the index to create.
|
||||
/// * `node_label` - The label of the nodes to which the index will be applied. For example, if your nodes have
|
||||
/// the label `:Movie`, pass "Movie" as the `node_label` parameter.
|
||||
/// the label `:Movie`, pass "Movie" as the `node_label` parameter.
|
||||
/// * `embedding_prop_name` (optional) - The name of the property that contains the embedding vectors. Defaults to "embedding".
|
||||
///
|
||||
pub async fn create_vector_index(
|
||||
|
|
|
@ -211,7 +211,7 @@ impl<M: EmbeddingModel + std::marker::Sync + Send> VectorStoreIndex for Neo4jVec
|
|||
/// #### Generic Type Parameters
|
||||
///
|
||||
/// - `T`: The type used to deserialize the result from the Neo4j query.
|
||||
/// It must implement the `serde::Deserialize` trait.
|
||||
/// It must implement the `serde::Deserialize` trait.
|
||||
///
|
||||
/// #### Returns
|
||||
///
|
||||
|
|
|
@ -29,7 +29,7 @@ impl<M: EmbeddingModel> QdrantVectorStore<M> {
|
|||
/// * `client` - Qdrant client instance
|
||||
/// * `model` - Embedding model instance
|
||||
/// * `query_params` - Search parameters for vector queries
|
||||
/// Reference: <https://api.qdrant.tech/v-1-12-x/api-reference/search/query-points>
|
||||
/// Reference: <https://api.qdrant.tech/v-1-12-x/api-reference/search/query-points>
|
||||
pub fn new(client: Qdrant, model: M, query_params: QueryPoints) -> Self {
|
||||
Self {
|
||||
client,
|
||||
|
|
Loading…
Reference in New Issue