This commit is contained in:
Joshua Mo 2025-04-18 12:43:35 +01:00 committed by GitHub
commit b65c078261
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 138 additions and 52 deletions

View File

@ -13,7 +13,7 @@ serde = { version = "1.0.193", features = ["derive"] }
serde_json = "1.0.108"
tracing = "0.1.40"
schemars = "0.8.16"
fastembed = "4.4.0"
fastembed = "4.6.0"
[dev-dependencies]
anyhow = "1.0.75"
@ -26,3 +26,6 @@ required-features = ["rig-core/derive"]
[[example]]
name = "vector_search_local"
required-features = ["rig-core/derive"]
[[example]]
name = "embed_images"

View File

@ -0,0 +1,18 @@
use rig::embeddings::embedding::ImageEmbeddingModel as _;
use rig_fastembed::FastembedImageModel;
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
let fastembed_client = rig_fastembed::Client::new();
let embedding_model =
fastembed_client.image_embedding_model(&FastembedImageModel::NomicEmbedVisionV15);
let bytes = std::fs::read("image.png").unwrap();
let res = embedding_model.embed_image(&bytes).await.unwrap();
println!("{}", res.vec.len());
Ok(())
}

View File

@ -3,7 +3,7 @@ use rig::{
vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex},
Embed,
};
use rig_fastembed::FastembedModel;
use rig_fastembed::FastembedTextModel;
use serde::{Deserialize, Serialize};
// Shape of data that needs to be RAG'ed.
@ -18,10 +18,11 @@ struct WordDefinition {
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Create OpenAI client
// Create fastembed client
let fastembed_client = rig_fastembed::Client::new();
let embedding_model = fastembed_client.embedding_model(&FastembedModel::AllMiniLML6V2Q);
let embedding_model =
fastembed_client.text_embedding_model(&FastembedTextModel::AllMiniLML6V2Q);
let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
.documents(vec![

View File

@ -7,7 +7,7 @@ use rig::{
vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex},
Embed,
};
use rig_fastembed::EmbeddingModel;
use rig_fastembed::TextEmbeddingModel;
use serde::{Deserialize, Serialize};
use std::path::Path;
@ -51,7 +51,7 @@ async fn main() -> Result<(), anyhow::Error> {
UserDefinedEmbeddingModel::new(onnx_file, tokenizer_files).with_pooling(Pooling::Mean);
let embedding_model =
EmbeddingModel::new_from_user_defined(user_defined_model, 384, &test_model_info);
TextEmbeddingModel::new_from_user_defined(user_defined_model, 384, &test_model_info);
// Create documents
let documents = vec![

View File

@ -1,8 +1,11 @@
use std::sync::Arc;
pub use fastembed::EmbeddingModel as FastembedModel;
pub use fastembed::EmbeddingModel as FastembedTextModel;
pub use fastembed::ImageEmbeddingModel as FastembedImageModel;
use fastembed::{
InitOptions, InitOptionsUserDefined, ModelInfo, TextEmbedding, UserDefinedEmbeddingModel,
ImageEmbedding, ImageInitOptions, ImageInitOptionsUserDefined, InitOptions,
InitOptionsUserDefined, ModelInfo, TextEmbedding, UserDefinedEmbeddingModel,
UserDefinedImageEmbeddingModel,
};
use rig::{
embeddings::{self, EmbeddingError, EmbeddingsBuilder},
@ -24,7 +27,7 @@ impl Client {
Self
}
/// Create an embedding model with the given name.
/// Create a text embedding model with the given name.
/// Note: default embedding dimension of 0 will be used if model is not known.
/// If this is the case, it's better to use function `embedding_model_with_ndims`
///
@ -37,10 +40,28 @@ impl Client {
///
/// let embedding_model = fastembed_client.embedding_model(&FastembedModel::AllMiniLML6V2Q);
/// ```
pub fn embedding_model(&self, model: &FastembedModel) -> EmbeddingModel {
let ndims = fetch_model_ndims(model);
pub fn text_embedding_model(&self, model: &FastembedTextModel) -> TextEmbeddingModel {
TextEmbeddingModel::new(model)
}
EmbeddingModel::new(model, ndims)
/// Create an image embedding model with the given name.
/// Note: default embedding dimension of 0 will be used if model is not known.
/// If this is the case, it's better to use function `embedding_model_with_ndims`
///
/// # Example
/// ```
/// use rig_fastembed::{Client, FastembedModel};
///
/// // Initialize the OpenAI client
/// let fastembed_client = Client::new("your-open-ai-api-key");
///
/// let embedding_model = fastembed_client.image_embedding_model(&FastembedModel::AllMiniLML6V2Q);
/// ```
pub fn image_embedding_model(
&self,
model: &fastembed::ImageEmbeddingModel,
) -> ImageEmbeddingModel {
ImageEmbeddingModel::new(model)
}
/// Create an embedding builder with the given embedding model.
@ -59,23 +80,23 @@ impl Client {
/// .await
/// .expect("Failed to embed documents");
/// ```
pub fn embeddings<D: Embed>(
pub fn text_embeddings<D: Embed>(
&self,
model: &fastembed::EmbeddingModel,
) -> EmbeddingsBuilder<EmbeddingModel, D> {
EmbeddingsBuilder::new(self.embedding_model(model))
model: &FastembedTextModel,
) -> EmbeddingsBuilder<TextEmbeddingModel, D> {
EmbeddingsBuilder::new(self.text_embedding_model(model))
}
}
#[derive(Clone)]
pub struct EmbeddingModel {
pub struct TextEmbeddingModel {
embedder: Arc<TextEmbedding>,
pub model: FastembedModel,
pub model: FastembedTextModel,
ndims: usize,
}
impl EmbeddingModel {
pub fn new(model: &fastembed::EmbeddingModel, ndims: usize) -> Self {
impl TextEmbeddingModel {
pub fn new(model: &fastembed::EmbeddingModel) -> Self {
let embedder = Arc::new(
TextEmbedding::try_new(
InitOptions::new(model.to_owned()).with_show_download_progress(true),
@ -83,6 +104,8 @@ impl EmbeddingModel {
.unwrap(),
);
let ndims = TextEmbedding::get_model_info(model).unwrap().dim;
Self {
embedder,
model: model.to_owned(),
@ -93,7 +116,7 @@ impl EmbeddingModel {
pub fn new_from_user_defined(
user_defined_model: UserDefinedEmbeddingModel,
ndims: usize,
model_info: &ModelInfo<FastembedModel>,
model_info: &ModelInfo<FastembedTextModel>,
) -> Self {
let fastembed_embedding_model = TextEmbedding::try_new_from_user_defined(
user_defined_model,
@ -111,7 +134,7 @@ impl EmbeddingModel {
}
}
impl embeddings::EmbeddingModel for EmbeddingModel {
impl embeddings::EmbeddingModel for TextEmbeddingModel {
const MAX_DOCUMENTS: usize = 1024;
fn ndims(&self) -> usize {
@ -142,35 +165,76 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
}
}
/// As seen on the text embedding model cards file: <https://github.com/Anush008/fastembed-rs/blob/main/src/models/text_embedding.rs>
pub fn fetch_model_ndims(model: &FastembedModel) -> usize {
match model {
FastembedModel::AllMiniLML6V2
| FastembedModel::AllMiniLML6V2Q
| FastembedModel::AllMiniLML12V2
| FastembedModel::AllMiniLML12V2Q
| FastembedModel::BGESmallENV15
| FastembedModel::BGESmallENV15Q
| FastembedModel::ParaphraseMLMiniLML12V2Q
| FastembedModel::ParaphraseMLMiniLML12V2
| FastembedModel::MultilingualE5Small => 384,
FastembedModel::BGESmallZHV15 | FastembedModel::ClipVitB32 => 512,
FastembedModel::BGEBaseENV15
| FastembedModel::BGEBaseENV15Q
| FastembedModel::NomicEmbedTextV1
| FastembedModel::NomicEmbedTextV15
| FastembedModel::NomicEmbedTextV15Q
| FastembedModel::ParaphraseMLMpnetBaseV2
| FastembedModel::MultilingualE5Base
| FastembedModel::GTEBaseENV15
| FastembedModel::GTEBaseENV15Q
| FastembedModel::JinaEmbeddingsV2BaseCode => 768,
FastembedModel::BGELargeENV15
| FastembedModel::BGELargeENV15Q
| FastembedModel::MultilingualE5Large
| FastembedModel::MxbaiEmbedLargeV1
| FastembedModel::MxbaiEmbedLargeV1Q
| FastembedModel::GTELargeENV15
| FastembedModel::GTELargeENV15Q => 1024,
#[derive(Clone)]
pub struct ImageEmbeddingModel {
embedder: Arc<ImageEmbedding>,
pub model: FastembedImageModel,
ndims: usize,
}
impl ImageEmbeddingModel {
pub fn new(model: &FastembedImageModel) -> Self {
let embedder = Arc::new(
ImageEmbedding::try_new(
ImageInitOptions::new(model.to_owned()).with_show_download_progress(true),
)
.unwrap(),
);
let ndims = ImageEmbedding::get_model_info(model).dim;
Self {
embedder,
model: model.to_owned(),
ndims,
}
}
pub fn new_from_user_defined(
user_defined_model: UserDefinedImageEmbeddingModel,
ndims: usize,
model_info: &ModelInfo<FastembedImageModel>,
) -> Self {
let fastembed_embedding_model = ImageEmbedding::try_new_from_user_defined(
user_defined_model,
ImageInitOptionsUserDefined::default(),
)
.unwrap();
let embedder = Arc::new(fastembed_embedding_model);
Self {
embedder,
model: model_info.model.to_owned(),
ndims,
}
}
}
impl rig::embeddings::embedding::ImageEmbeddingModel for ImageEmbeddingModel {
const MAX_DOCUMENTS: usize = 1024;
fn ndims(&self) -> usize {
self.ndims
}
async fn embed_images(
&self,
documents: impl IntoIterator<Item = Vec<u8>>,
) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
let images: Vec<Vec<u8>> = documents.into_iter().collect();
let images: Vec<&[u8]> = images.iter().map(|x| x.as_slice()).collect();
let images_as_vec = self.embedder.embed_bytes(&images, None).unwrap();
let docs = images_as_vec
.into_iter()
.map(|embedding| embeddings::Embedding {
document: String::new(),
vec: embedding.into_iter().map(|f| f as f64).collect(),
})
.collect::<Vec<embeddings::Embedding>>();
Ok(docs)
}
}