fix: AWS Bedrock documents

feat: Bedrock requires each document to have a unique name, so I added a fingerprint based on the document content and use it as the document name
fix: decode document content based on ContentFormat variant
This commit is contained in:
Marko Kranjac 2025-04-17 19:25:02 +02:00
parent 522d3f6ba1
commit 265497688f
3 changed files with 75 additions and 35 deletions

47
Cargo.lock generated
View File

@ -8622,7 +8622,7 @@ checksum = "57397d16646700483b67d2dd6511d79318f9d057fdbd21a4066aeac8b41d310a"
[[package]]
name = "rig-bedrock"
version = "0.1.1"
version = "0.1.2"
dependencies = [
"anyhow",
"async-stream",
@ -8631,8 +8631,9 @@ dependencies = [
"aws-smithy-types",
"base64 0.22.1",
"reqwest 0.12.15",
"rig-core 0.11.0",
"rig-core 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
"rig-derive",
"ring 0.17.14",
"schemars",
"serde",
"serde_json",
@ -8641,27 +8642,6 @@ dependencies = [
"tracing-subscriber",
]
[[package]]
name = "rig-core"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ff893305131b471009ab11df388612beb603ed94bb12412c256fe197b7591aa6"
dependencies = [
"async-stream",
"base64 0.22.1",
"bytes",
"futures",
"glob",
"mime_guess",
"ordered-float",
"reqwest 0.12.15",
"schemars",
"serde",
"serde_json",
"thiserror 1.0.69",
"tracing",
]
[[package]]
name = "rig-core"
version = "0.11.1"
@ -8695,6 +8675,27 @@ dependencies = [
"worker",
]
[[package]]
name = "rig-core"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb610bd7e61825e79ca79b7efcad93206256147e27cbf707dffd80b7622b5ca7"
dependencies = [
"async-stream",
"base64 0.22.1",
"bytes",
"futures",
"glob",
"mime_guess",
"ordered-float",
"reqwest 0.12.15",
"schemars",
"serde",
"serde_json",
"thiserror 1.0.69",
"tracing",
]
[[package]]
name = "rig-derive"
version = "0.1.1"

View File

@ -1,13 +1,13 @@
[package]
name = "rig-bedrock"
version = "0.1.1"
version = "0.1.2"
edition = "2021"
license = "MIT"
readme = "README.md"
description = "AWS Bedrock model provider for Rig integration."
[dependencies]
rig-core = { version = "0.11.0", features = ["image"] }
rig-core = { version = "0.11.1", features = ["image"] }
rig-derive = { path = "../rig-core/rig-core-derive", version = "0.1.1" }
serde = { version = "1.0.193", features = ["derive"] }
serde_json = "1.0.108"
@ -18,6 +18,7 @@ aws-sdk-bedrockruntime = "1.77.0"
aws-smithy-types = "1.3.0"
base64 = "0.22.1"
async-stream = "0.3.6"
ring = "0.17.14"
[dev-dependencies]
anyhow = "1.0.75"

View File

@ -1,5 +1,4 @@
use aws_sdk_bedrockruntime::types as aws_bedrock;
use rig::{
completion::CompletionError,
message::{ContentFormat, Document},
@ -7,35 +6,55 @@ use rig::{
pub(crate) use crate::types::media_types::RigDocumentMediaType;
use base64::{prelude::BASE64_STANDARD, Engine};
use ring::digest::{Context, Digest, SHA256};
#[derive(Clone)]
pub struct RigDocument(pub Document);
impl RigDocument {
    /// Returns a content-derived fingerprint for this document.
    ///
    /// The fingerprint is the SHA-256 digest of the document's `data` string
    /// exactly as stored (the bytes of the string itself, without any
    /// decoding), rendered as lowercase hexadecimal. Identical `data` always
    /// yields the same fingerprint.
    ///
    /// # Returns
    ///
    /// A hex string with two characters per digest byte.
    pub fn fingerprint(&self) -> String {
        let mut context = Context::new(&SHA256);
        context.update(self.0.data.as_bytes());
        let digest: Digest = context.finish();

        // Iterate the digest bytes in place; copying them into a temporary
        // `Vec` first is unnecessary.
        digest
            .as_ref()
            .iter()
            .map(|byte| format!("{:02x}", byte))
            .collect()
    }
}
impl TryFrom<RigDocument> for aws_bedrock::DocumentBlock {
type Error = CompletionError;
fn try_from(value: RigDocument) -> Result<Self, Self::Error> {
let maybe_format = value
let document_name = value.fingerprint();
let document_media_type = value
.0
.media_type
.map(|doc| RigDocumentMediaType(doc).try_into());
let format = match maybe_format {
let document_media_type = match document_media_type {
Some(Ok(document_format)) => Ok(Some(document_format)),
Some(Err(err)) => Err(err),
None => Ok(None),
}?;
let document_data = BASE64_STANDARD
let document_data = match value.0.format {
Some(ContentFormat::Base64) => BASE64_STANDARD
.decode(value.0.data)
.map_err(|e| CompletionError::ProviderError(e.to_string()))?;
.map_err(|e| CompletionError::ProviderError(e.to_string()))?,
_ => value.0.data.as_bytes().to_vec(),
};
let data = aws_smithy_types::Blob::new(document_data);
let document_source = aws_bedrock::DocumentSource::Bytes(data);
let result = aws_bedrock::DocumentBlock::builder()
.source(document_source)
.name("Document")
.set_format(format)
.name(document_name)
.set_format(document_media_type)
.build()
.map_err(|e| CompletionError::ProviderError(e.to_string()))?;
Ok(result)
@ -82,13 +101,32 @@ mod tests {
fn test_document_to_aws_document() {
let rig_document = RigDocument(Document {
data: "data".into(),
format: Some(ContentFormat::Base64),
format: Some(ContentFormat::String),
media_type: Some(DocumentMediaType::PDF),
});
let aws_document: Result<aws_bedrock::DocumentBlock, _> = rig_document.clone().try_into();
assert_eq!(aws_document.is_ok(), true);
let aws_document = aws_document.unwrap();
assert_eq!(aws_document.format, aws_bedrock::DocumentFormat::Pdf);
let document_data = rig_document.0.data.as_bytes().to_vec();
let aws_document_bytes = aws_document
.source()
.unwrap()
.as_bytes()
.unwrap()
.as_ref()
.to_owned();
assert_eq!(aws_document_bytes, document_data)
}
#[test]
fn test_base64_document_to_aws_document() {
let rig_document = RigDocument(Document {
data: "data".into(),
format: Some(ContentFormat::Base64),
media_type: Some(DocumentMediaType::PDF),
});
let aws_document: aws_bedrock::DocumentBlock = rig_document.clone().try_into().unwrap();
let document_data = BASE64_STANDARD.decode(rig_document.0.data).unwrap();
let aws_document_bytes = aws_document
.source()
@ -104,7 +142,7 @@ mod tests {
fn test_unsupported_document_to_aws_document() {
let rig_document = RigDocument(Document {
data: "data".into(),
format: Some(ContentFormat::Base64),
format: Some(ContentFormat::String),
media_type: Some(DocumentMediaType::Javascript),
});
let aws_document: Result<aws_bedrock::DocumentBlock, _> = rig_document.clone().try_into();