This commit is contained in:
Joshua Mo 2025-04-18 12:43:35 +01:00 committed by GitHub
commit b65c078261
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 138 additions and 52 deletions

View File

@ -13,7 +13,7 @@ serde = { version = "1.0.193", features = ["derive"] }
serde_json = "1.0.108" serde_json = "1.0.108"
tracing = "0.1.40" tracing = "0.1.40"
schemars = "0.8.16" schemars = "0.8.16"
fastembed = "4.4.0" fastembed = "4.6.0"
[dev-dependencies] [dev-dependencies]
anyhow = "1.0.75" anyhow = "1.0.75"
@ -26,3 +26,6 @@ required-features = ["rig-core/derive"]
[[example]] [[example]]
name = "vector_search_local" name = "vector_search_local"
required-features = ["rig-core/derive"] required-features = ["rig-core/derive"]
[[example]]
name = "embed_images"

View File

@ -0,0 +1,18 @@
use rig::embeddings::embedding::ImageEmbeddingModel as _;
use rig_fastembed::FastembedImageModel;
/// Embeds a single image with the fastembed Nomic vision model and prints the
/// length of the resulting embedding vector.
///
/// The image is read from the relative path `image.png`.
///
/// # Errors
/// Returns an error if `image.png` cannot be read or if embedding the image fails.
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    let fastembed_client = rig_fastembed::Client::new();
    let embedding_model =
        fastembed_client.image_embedding_model(&FastembedImageModel::NomicEmbedVisionV15);

    // `main` already returns a `Result`, so propagate failures instead of panicking.
    let bytes = std::fs::read("image.png")?;
    let res = embedding_model.embed_image(&bytes).await?;

    println!("{}", res.vec.len());

    Ok(())
}

View File

@ -3,7 +3,7 @@ use rig::{
vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex},
Embed, Embed,
}; };
use rig_fastembed::FastembedModel; use rig_fastembed::FastembedTextModel;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
// Shape of data that needs to be RAG'ed. // Shape of data that needs to be RAG'ed.
@ -18,10 +18,11 @@ struct WordDefinition {
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), anyhow::Error> { async fn main() -> Result<(), anyhow::Error> {
// Create OpenAI client // Create fastembed client
let fastembed_client = rig_fastembed::Client::new(); let fastembed_client = rig_fastembed::Client::new();
let embedding_model = fastembed_client.embedding_model(&FastembedModel::AllMiniLML6V2Q); let embedding_model =
fastembed_client.text_embedding_model(&FastembedTextModel::AllMiniLML6V2Q);
let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
.documents(vec![ .documents(vec![

View File

@ -7,7 +7,7 @@ use rig::{
vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex},
Embed, Embed,
}; };
use rig_fastembed::EmbeddingModel; use rig_fastembed::TextEmbeddingModel;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::path::Path; use std::path::Path;
@ -51,7 +51,7 @@ async fn main() -> Result<(), anyhow::Error> {
UserDefinedEmbeddingModel::new(onnx_file, tokenizer_files).with_pooling(Pooling::Mean); UserDefinedEmbeddingModel::new(onnx_file, tokenizer_files).with_pooling(Pooling::Mean);
let embedding_model = let embedding_model =
EmbeddingModel::new_from_user_defined(user_defined_model, 384, &test_model_info); TextEmbeddingModel::new_from_user_defined(user_defined_model, 384, &test_model_info);
// Create documents // Create documents
let documents = vec![ let documents = vec![

View File

@ -1,8 +1,11 @@
use std::sync::Arc; use std::sync::Arc;
pub use fastembed::EmbeddingModel as FastembedModel; pub use fastembed::EmbeddingModel as FastembedTextModel;
pub use fastembed::ImageEmbeddingModel as FastembedImageModel;
use fastembed::{ use fastembed::{
InitOptions, InitOptionsUserDefined, ModelInfo, TextEmbedding, UserDefinedEmbeddingModel, ImageEmbedding, ImageInitOptions, ImageInitOptionsUserDefined, InitOptions,
InitOptionsUserDefined, ModelInfo, TextEmbedding, UserDefinedEmbeddingModel,
UserDefinedImageEmbeddingModel,
}; };
use rig::{ use rig::{
embeddings::{self, EmbeddingError, EmbeddingsBuilder}, embeddings::{self, EmbeddingError, EmbeddingsBuilder},
@ -24,7 +27,7 @@ impl Client {
Self Self
} }
/// Create an embedding model with the given name. /// Create a text embedding model with the given name.
/// Note: default embedding dimension of 0 will be used if model is not known. /// Note: default embedding dimension of 0 will be used if model is not known.
/// If this is the case, it's better to use function `embedding_model_with_ndims` /// If this is the case, it's better to use function `embedding_model_with_ndims`
/// ///
@ -37,10 +40,28 @@ impl Client {
/// ///
/// let embedding_model = fastembed_client.embedding_model(&FastembedModel::AllMiniLML6V2Q); /// let embedding_model = fastembed_client.embedding_model(&FastembedModel::AllMiniLML6V2Q);
/// ``` /// ```
pub fn embedding_model(&self, model: &FastembedModel) -> EmbeddingModel { pub fn text_embedding_model(&self, model: &FastembedTextModel) -> TextEmbeddingModel {
let ndims = fetch_model_ndims(model); TextEmbeddingModel::new(model)
}
EmbeddingModel::new(model, ndims) /// Create an image embedding model with the given name.
/// The embedding dimension is taken from fastembed's model info for the given model.
/// For a user-defined model, use `ImageEmbeddingModel::new_from_user_defined`, which
/// takes the dimension explicitly.
///
/// # Example
/// ```
/// use rig_fastembed::{Client, FastembedImageModel};
///
/// // Initialize the fastembed client
/// let fastembed_client = Client::new();
///
/// let embedding_model = fastembed_client.image_embedding_model(&FastembedImageModel::NomicEmbedVisionV15);
/// ```
pub fn image_embedding_model(
&self,
model: &fastembed::ImageEmbeddingModel,
) -> ImageEmbeddingModel {
ImageEmbeddingModel::new(model)
} }
/// Create an embedding builder with the given embedding model. /// Create an embedding builder with the given embedding model.
@ -59,23 +80,23 @@ impl Client {
/// .await /// .await
/// .expect("Failed to embed documents"); /// .expect("Failed to embed documents");
/// ``` /// ```
pub fn embeddings<D: Embed>( pub fn text_embeddings<D: Embed>(
&self, &self,
model: &fastembed::EmbeddingModel, model: &FastembedTextModel,
) -> EmbeddingsBuilder<EmbeddingModel, D> { ) -> EmbeddingsBuilder<TextEmbeddingModel, D> {
EmbeddingsBuilder::new(self.embedding_model(model)) EmbeddingsBuilder::new(self.text_embedding_model(model))
} }
} }
#[derive(Clone)] #[derive(Clone)]
pub struct EmbeddingModel { pub struct TextEmbeddingModel {
embedder: Arc<TextEmbedding>, embedder: Arc<TextEmbedding>,
pub model: FastembedModel, pub model: FastembedTextModel,
ndims: usize, ndims: usize,
} }
impl EmbeddingModel { impl TextEmbeddingModel {
pub fn new(model: &fastembed::EmbeddingModel, ndims: usize) -> Self { pub fn new(model: &fastembed::EmbeddingModel) -> Self {
let embedder = Arc::new( let embedder = Arc::new(
TextEmbedding::try_new( TextEmbedding::try_new(
InitOptions::new(model.to_owned()).with_show_download_progress(true), InitOptions::new(model.to_owned()).with_show_download_progress(true),
@ -83,6 +104,8 @@ impl EmbeddingModel {
.unwrap(), .unwrap(),
); );
let ndims = TextEmbedding::get_model_info(model).unwrap().dim;
Self { Self {
embedder, embedder,
model: model.to_owned(), model: model.to_owned(),
@ -93,7 +116,7 @@ impl EmbeddingModel {
pub fn new_from_user_defined( pub fn new_from_user_defined(
user_defined_model: UserDefinedEmbeddingModel, user_defined_model: UserDefinedEmbeddingModel,
ndims: usize, ndims: usize,
model_info: &ModelInfo<FastembedModel>, model_info: &ModelInfo<FastembedTextModel>,
) -> Self { ) -> Self {
let fastembed_embedding_model = TextEmbedding::try_new_from_user_defined( let fastembed_embedding_model = TextEmbedding::try_new_from_user_defined(
user_defined_model, user_defined_model,
@ -111,7 +134,7 @@ impl EmbeddingModel {
} }
} }
impl embeddings::EmbeddingModel for EmbeddingModel { impl embeddings::EmbeddingModel for TextEmbeddingModel {
const MAX_DOCUMENTS: usize = 1024; const MAX_DOCUMENTS: usize = 1024;
fn ndims(&self) -> usize { fn ndims(&self) -> usize {
@ -142,35 +165,76 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
} }
} }
/// As seen on the text embedding model cards file: <https://github.com/Anush008/fastembed-rs/blob/main/src/models/text_embedding.rs> #[derive(Clone)]
pub fn fetch_model_ndims(model: &FastembedModel) -> usize { pub struct ImageEmbeddingModel {
match model { embedder: Arc<ImageEmbedding>,
FastembedModel::AllMiniLML6V2 pub model: FastembedImageModel,
| FastembedModel::AllMiniLML6V2Q ndims: usize,
| FastembedModel::AllMiniLML12V2 }
| FastembedModel::AllMiniLML12V2Q
| FastembedModel::BGESmallENV15 impl ImageEmbeddingModel {
| FastembedModel::BGESmallENV15Q pub fn new(model: &FastembedImageModel) -> Self {
| FastembedModel::ParaphraseMLMiniLML12V2Q let embedder = Arc::new(
| FastembedModel::ParaphraseMLMiniLML12V2 ImageEmbedding::try_new(
| FastembedModel::MultilingualE5Small => 384, ImageInitOptions::new(model.to_owned()).with_show_download_progress(true),
FastembedModel::BGESmallZHV15 | FastembedModel::ClipVitB32 => 512, )
FastembedModel::BGEBaseENV15 .unwrap(),
| FastembedModel::BGEBaseENV15Q );
| FastembedModel::NomicEmbedTextV1
| FastembedModel::NomicEmbedTextV15 let ndims = ImageEmbedding::get_model_info(model).dim;
| FastembedModel::NomicEmbedTextV15Q
| FastembedModel::ParaphraseMLMpnetBaseV2 Self {
| FastembedModel::MultilingualE5Base embedder,
| FastembedModel::GTEBaseENV15 model: model.to_owned(),
| FastembedModel::GTEBaseENV15Q ndims,
| FastembedModel::JinaEmbeddingsV2BaseCode => 768, }
FastembedModel::BGELargeENV15 }
| FastembedModel::BGELargeENV15Q
| FastembedModel::MultilingualE5Large pub fn new_from_user_defined(
| FastembedModel::MxbaiEmbedLargeV1 user_defined_model: UserDefinedImageEmbeddingModel,
| FastembedModel::MxbaiEmbedLargeV1Q ndims: usize,
| FastembedModel::GTELargeENV15 model_info: &ModelInfo<FastembedImageModel>,
| FastembedModel::GTELargeENV15Q => 1024, ) -> Self {
let fastembed_embedding_model = ImageEmbedding::try_new_from_user_defined(
user_defined_model,
ImageInitOptionsUserDefined::default(),
)
.unwrap();
let embedder = Arc::new(fastembed_embedding_model);
Self {
embedder,
model: model_info.model.to_owned(),
ndims,
}
}
}
impl rig::embeddings::embedding::ImageEmbeddingModel for ImageEmbeddingModel {
const MAX_DOCUMENTS: usize = 1024;
fn ndims(&self) -> usize {
self.ndims
}
async fn embed_images(
&self,
documents: impl IntoIterator<Item = Vec<u8>>,
) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
let images: Vec<Vec<u8>> = documents.into_iter().collect();
let images: Vec<&[u8]> = images.iter().map(|x| x.as_slice()).collect();
let images_as_vec = self.embedder.embed_bytes(&images, None).unwrap();
let docs = images_as_vec
.into_iter()
.map(|embedding| embeddings::Embedding {
document: String::new(),
vec: embedding.into_iter().map(|f| f as f64).collect(),
})
.collect::<Vec<embeddings::Embedding>>();
Ok(docs)
} }
} }