diff --git a/rig-fastembed/Cargo.toml b/rig-fastembed/Cargo.toml index 2d7f22f..3d80b05 100644 --- a/rig-fastembed/Cargo.toml +++ b/rig-fastembed/Cargo.toml @@ -13,7 +13,7 @@ serde = { version = "1.0.193", features = ["derive"] } serde_json = "1.0.108" tracing = "0.1.40" schemars = "0.8.16" -fastembed = "4.4.0" +fastembed = "4.6.0" [dev-dependencies] anyhow = "1.0.75" @@ -26,3 +26,6 @@ required-features = ["rig-core/derive"] [[example]] name = "vector_search_local" required-features = ["rig-core/derive"] + +[[example]] +name = "embed_images" diff --git a/rig-fastembed/examples/embed_images.rs b/rig-fastembed/examples/embed_images.rs new file mode 100644 index 0000000..7308b2b --- /dev/null +++ b/rig-fastembed/examples/embed_images.rs @@ -0,0 +1,18 @@ +use rig::embeddings::embedding::ImageEmbeddingModel as _; +use rig_fastembed::FastembedImageModel; + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + let fastembed_client = rig_fastembed::Client::new(); + + let embedding_model = + fastembed_client.image_embedding_model(&FastembedImageModel::NomicEmbedVisionV15); + + let bytes = std::fs::read("image.png").unwrap(); + + let res = embedding_model.embed_image(&bytes).await.unwrap(); + + println!("{}", res.vec.len()); + + Ok(()) +} diff --git a/rig-fastembed/examples/vector_search.rs b/rig-fastembed/examples/vector_search.rs index fa7a3e7..b81d069 100644 --- a/rig-fastembed/examples/vector_search.rs +++ b/rig-fastembed/examples/vector_search.rs @@ -3,7 +3,7 @@ use rig::{ vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, Embed, }; -use rig_fastembed::FastembedModel; +use rig_fastembed::FastembedTextModel; use serde::{Deserialize, Serialize}; // Shape of data that needs to be RAG'ed. @@ -18,10 +18,11 @@ struct WordDefinition { #[tokio::main] async fn main() -> Result<(), anyhow::Error> { - // Create OpenAI client + // Create fastembed client let fastembed_client = rig_fastembed::Client::new(); - let embedding_model = fastembed_client.embedding_model(&FastembedModel::AllMiniLML6V2Q); + let embedding_model = + fastembed_client.text_embedding_model(&FastembedTextModel::AllMiniLML6V2Q); let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) .documents(vec![ diff --git a/rig-fastembed/examples/vector_search_local.rs b/rig-fastembed/examples/vector_search_local.rs index 5a0d433..db54208 100644 --- a/rig-fastembed/examples/vector_search_local.rs +++ b/rig-fastembed/examples/vector_search_local.rs @@ -7,7 +7,7 @@ use rig::{ vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, Embed, }; -use rig_fastembed::EmbeddingModel; +use rig_fastembed::TextEmbeddingModel; use serde::{Deserialize, Serialize}; use std::path::Path; @@ -51,7 +51,7 @@ async fn main() -> Result<(), anyhow::Error> { UserDefinedEmbeddingModel::new(onnx_file, tokenizer_files).with_pooling(Pooling::Mean); let embedding_model = - EmbeddingModel::new_from_user_defined(user_defined_model, 384, &test_model_info); + TextEmbeddingModel::new_from_user_defined(user_defined_model, 384, &test_model_info); // Create documents let documents = vec![ diff --git a/rig-fastembed/src/lib.rs b/rig-fastembed/src/lib.rs index 21a286c..0a0c2ed 100644 --- a/rig-fastembed/src/lib.rs +++ b/rig-fastembed/src/lib.rs @@ -1,8 +1,11 @@ use std::sync::Arc; -pub use fastembed::EmbeddingModel as FastembedModel; +pub use fastembed::EmbeddingModel as FastembedTextModel; +pub use fastembed::ImageEmbeddingModel as FastembedImageModel; use fastembed::{ - InitOptions, InitOptionsUserDefined, ModelInfo, TextEmbedding, UserDefinedEmbeddingModel, + ImageEmbedding, ImageInitOptions, ImageInitOptionsUserDefined, InitOptions, + InitOptionsUserDefined, ModelInfo, TextEmbedding, UserDefinedEmbeddingModel, + UserDefinedImageEmbeddingModel, }; use rig::{ embeddings::{self, EmbeddingError, EmbeddingsBuilder}, @@ -24,7 +27,7 @@ impl Client { Self } - /// Create an embedding model with the given name. + /// Create a text embedding model with the given name. /// Note: default embedding dimension of 0 will be used if model is not known. /// If this is the case, it's better to use function `embedding_model_with_ndims` /// @@ -37,10 +40,28 @@ impl Client { /// /// let embedding_model = fastembed_client.embedding_model(&FastembedModel::AllMiniLML6V2Q); /// ``` - pub fn embedding_model(&self, model: &FastembedModel) -> EmbeddingModel { - let ndims = fetch_model_ndims(model); + pub fn text_embedding_model(&self, model: &FastembedTextModel) -> TextEmbeddingModel { + TextEmbeddingModel::new(model) + } - EmbeddingModel::new(model, ndims) + /// Create an image embedding model with the given name. + /// Note: default embedding dimension of 0 will be used if model is not known. + /// If this is the case, it's better to use function `embedding_model_with_ndims` + /// + /// # Example + /// ``` + /// use rig_fastembed::{Client, FastembedModel}; + /// + /// // Initialize the OpenAI client + /// let fastembed_client = Client::new("your-open-ai-api-key"); + /// + /// let embedding_model = fastembed_client.image_embedding_model(&FastembedModel::AllMiniLML6V2Q); + /// ``` + pub fn image_embedding_model( + &self, + model: &fastembed::ImageEmbeddingModel, + ) -> ImageEmbeddingModel { + ImageEmbeddingModel::new(model) } /// Create an embedding builder with the given embedding model. @@ -59,23 +80,23 @@ impl Client { /// .await /// .expect("Failed to embed documents"); /// ``` - pub fn embeddings( + pub fn text_embeddings( &self, - model: &fastembed::EmbeddingModel, - ) -> EmbeddingsBuilder { - EmbeddingsBuilder::new(self.embedding_model(model)) + model: &FastembedTextModel, + ) -> EmbeddingsBuilder { + EmbeddingsBuilder::new(self.text_embedding_model(model)) } } #[derive(Clone)] -pub struct EmbeddingModel { +pub struct TextEmbeddingModel { embedder: Arc, - pub model: FastembedModel, + pub model: FastembedTextModel, ndims: usize, } -impl EmbeddingModel { - pub fn new(model: &fastembed::EmbeddingModel, ndims: usize) -> Self { +impl TextEmbeddingModel { + pub fn new(model: &fastembed::EmbeddingModel) -> Self { let embedder = Arc::new( TextEmbedding::try_new( InitOptions::new(model.to_owned()).with_show_download_progress(true), @@ -83,6 +104,8 @@ impl EmbeddingModel { .unwrap(), ); + let ndims = TextEmbedding::get_model_info(model).unwrap().dim; + Self { embedder, model: model.to_owned(), @@ -93,7 +116,7 @@ impl EmbeddingModel { pub fn new_from_user_defined( user_defined_model: UserDefinedEmbeddingModel, ndims: usize, - model_info: &ModelInfo, + model_info: &ModelInfo, ) -> Self { let fastembed_embedding_model = TextEmbedding::try_new_from_user_defined( user_defined_model, @@ -111,7 +134,7 @@ impl EmbeddingModel { } } -impl embeddings::EmbeddingModel for EmbeddingModel { +impl embeddings::EmbeddingModel for TextEmbeddingModel { const MAX_DOCUMENTS: usize = 1024; fn ndims(&self) -> usize { @@ -142,35 +165,76 @@ impl embeddings::EmbeddingModel for EmbeddingModel { } } -/// As seen on the text embedding model cards file: -pub fn fetch_model_ndims(model: &FastembedModel) -> usize { - match model { - FastembedModel::AllMiniLML6V2 - | FastembedModel::AllMiniLML6V2Q - | FastembedModel::AllMiniLML12V2 - | FastembedModel::AllMiniLML12V2Q - | FastembedModel::BGESmallENV15 - | FastembedModel::BGESmallENV15Q - | FastembedModel::ParaphraseMLMiniLML12V2Q - | FastembedModel::ParaphraseMLMiniLML12V2 - | FastembedModel::MultilingualE5Small => 384, - FastembedModel::BGESmallZHV15 | FastembedModel::ClipVitB32 => 512, - FastembedModel::BGEBaseENV15 - | FastembedModel::BGEBaseENV15Q - | FastembedModel::NomicEmbedTextV1 - | FastembedModel::NomicEmbedTextV15 - | FastembedModel::NomicEmbedTextV15Q - | FastembedModel::ParaphraseMLMpnetBaseV2 - | FastembedModel::MultilingualE5Base - | FastembedModel::GTEBaseENV15 - | FastembedModel::GTEBaseENV15Q - | FastembedModel::JinaEmbeddingsV2BaseCode => 768, - FastembedModel::BGELargeENV15 - | FastembedModel::BGELargeENV15Q - | FastembedModel::MultilingualE5Large - | FastembedModel::MxbaiEmbedLargeV1 - | FastembedModel::MxbaiEmbedLargeV1Q - | FastembedModel::GTELargeENV15 - | FastembedModel::GTELargeENV15Q => 1024, +#[derive(Clone)] +pub struct ImageEmbeddingModel { + embedder: Arc, + pub model: FastembedImageModel, + ndims: usize, +} + +impl ImageEmbeddingModel { + pub fn new(model: &FastembedImageModel) -> Self { + let embedder = Arc::new( + ImageEmbedding::try_new( + ImageInitOptions::new(model.to_owned()).with_show_download_progress(true), + ) + .unwrap(), + ); + + let ndims = ImageEmbedding::get_model_info(model).dim; + + Self { + embedder, + model: model.to_owned(), + ndims, + } + } + + pub fn new_from_user_defined( + user_defined_model: UserDefinedImageEmbeddingModel, + ndims: usize, + model_info: &ModelInfo, + ) -> Self { + let fastembed_embedding_model = ImageEmbedding::try_new_from_user_defined( + user_defined_model, + ImageInitOptionsUserDefined::default(), + ) + .unwrap(); + + let embedder = Arc::new(fastembed_embedding_model); + + Self { + embedder, + model: model_info.model.to_owned(), + ndims, + } + } +} + +impl rig::embeddings::embedding::ImageEmbeddingModel for ImageEmbeddingModel { + const MAX_DOCUMENTS: usize = 1024; + + fn ndims(&self) -> usize { + self.ndims + } + + async fn embed_images( + &self, + documents: impl IntoIterator>, + ) -> Result, EmbeddingError> { + let images: Vec> = documents.into_iter().collect(); + let images: Vec<&[u8]> = images.iter().map(|x| x.as_slice()).collect(); + + let images_as_vec = self.embedder.embed_bytes(&images, None).unwrap(); + + let docs = images_as_vec + .into_iter() + .map(|embedding| embeddings::Embedding { + document: String::new(), + vec: embedding.into_iter().map(|f| f as f64).collect(), + }) + .collect::>(); + + Ok(docs) } }