mirror of https://github.com/0xplaygrounds/rig
Merge 12720160c2
into 33e8fc7a65
This commit is contained in:
commit
b65c078261
|
@ -13,7 +13,7 @@ serde = { version = "1.0.193", features = ["derive"] }
|
|||
serde_json = "1.0.108"
|
||||
tracing = "0.1.40"
|
||||
schemars = "0.8.16"
|
||||
fastembed = "4.4.0"
|
||||
fastembed = "4.6.0"
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = "1.0.75"
|
||||
|
@ -26,3 +26,6 @@ required-features = ["rig-core/derive"]
|
|||
[[example]]
|
||||
name = "vector_search_local"
|
||||
required-features = ["rig-core/derive"]
|
||||
|
||||
[[example]]
|
||||
name = "embed_images"
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
use rig::embeddings::embedding::ImageEmbeddingModel as _;
|
||||
use rig_fastembed::FastembedImageModel;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
let fastembed_client = rig_fastembed::Client::new();
|
||||
|
||||
let embedding_model =
|
||||
fastembed_client.image_embedding_model(&FastembedImageModel::NomicEmbedVisionV15);
|
||||
|
||||
let bytes = std::fs::read("image.png").unwrap();
|
||||
|
||||
let res = embedding_model.embed_image(&bytes).await.unwrap();
|
||||
|
||||
println!("{}", res.vec.len());
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -3,7 +3,7 @@ use rig::{
|
|||
vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex},
|
||||
Embed,
|
||||
};
|
||||
use rig_fastembed::FastembedModel;
|
||||
use rig_fastembed::FastembedTextModel;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// Shape of data that needs to be RAG'ed.
|
||||
|
@ -18,10 +18,11 @@ struct WordDefinition {
|
|||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
// Create OpenAI client
|
||||
// Create fastembed client
|
||||
let fastembed_client = rig_fastembed::Client::new();
|
||||
|
||||
let embedding_model = fastembed_client.embedding_model(&FastembedModel::AllMiniLML6V2Q);
|
||||
let embedding_model =
|
||||
fastembed_client.text_embedding_model(&FastembedTextModel::AllMiniLML6V2Q);
|
||||
|
||||
let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
|
||||
.documents(vec![
|
||||
|
|
|
@ -7,7 +7,7 @@ use rig::{
|
|||
vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex},
|
||||
Embed,
|
||||
};
|
||||
use rig_fastembed::EmbeddingModel;
|
||||
use rig_fastembed::TextEmbeddingModel;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
|
||||
|
@ -51,7 +51,7 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||
UserDefinedEmbeddingModel::new(onnx_file, tokenizer_files).with_pooling(Pooling::Mean);
|
||||
|
||||
let embedding_model =
|
||||
EmbeddingModel::new_from_user_defined(user_defined_model, 384, &test_model_info);
|
||||
TextEmbeddingModel::new_from_user_defined(user_defined_model, 384, &test_model_info);
|
||||
|
||||
// Create documents
|
||||
let documents = vec![
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
pub use fastembed::EmbeddingModel as FastembedModel;
|
||||
pub use fastembed::EmbeddingModel as FastembedTextModel;
|
||||
pub use fastembed::ImageEmbeddingModel as FastembedImageModel;
|
||||
use fastembed::{
|
||||
InitOptions, InitOptionsUserDefined, ModelInfo, TextEmbedding, UserDefinedEmbeddingModel,
|
||||
ImageEmbedding, ImageInitOptions, ImageInitOptionsUserDefined, InitOptions,
|
||||
InitOptionsUserDefined, ModelInfo, TextEmbedding, UserDefinedEmbeddingModel,
|
||||
UserDefinedImageEmbeddingModel,
|
||||
};
|
||||
use rig::{
|
||||
embeddings::{self, EmbeddingError, EmbeddingsBuilder},
|
||||
|
@ -24,7 +27,7 @@ impl Client {
|
|||
Self
|
||||
}
|
||||
|
||||
/// Create an embedding model with the given name.
|
||||
/// Create a text embedding model with the given name.
|
||||
/// Note: default embedding dimension of 0 will be used if model is not known.
|
||||
/// If this is the case, it's better to use function `embedding_model_with_ndims`
|
||||
///
|
||||
|
@ -37,10 +40,28 @@ impl Client {
|
|||
///
|
||||
/// let embedding_model = fastembed_client.embedding_model(&FastembedModel::AllMiniLML6V2Q);
|
||||
/// ```
|
||||
pub fn embedding_model(&self, model: &FastembedModel) -> EmbeddingModel {
|
||||
let ndims = fetch_model_ndims(model);
|
||||
pub fn text_embedding_model(&self, model: &FastembedTextModel) -> TextEmbeddingModel {
|
||||
TextEmbeddingModel::new(model)
|
||||
}
|
||||
|
||||
EmbeddingModel::new(model, ndims)
|
||||
/// Create an image embedding model with the given name.
|
||||
/// Note: default embedding dimension of 0 will be used if model is not known.
|
||||
/// If this is the case, it's better to use function `embedding_model_with_ndims`
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use rig_fastembed::{Client, FastembedModel};
|
||||
///
|
||||
/// // Initialize the OpenAI client
|
||||
/// let fastembed_client = Client::new("your-open-ai-api-key");
|
||||
///
|
||||
/// let embedding_model = fastembed_client.image_embedding_model(&FastembedModel::AllMiniLML6V2Q);
|
||||
/// ```
|
||||
pub fn image_embedding_model(
|
||||
&self,
|
||||
model: &fastembed::ImageEmbeddingModel,
|
||||
) -> ImageEmbeddingModel {
|
||||
ImageEmbeddingModel::new(model)
|
||||
}
|
||||
|
||||
/// Create an embedding builder with the given embedding model.
|
||||
|
@ -59,23 +80,23 @@ impl Client {
|
|||
/// .await
|
||||
/// .expect("Failed to embed documents");
|
||||
/// ```
|
||||
pub fn embeddings<D: Embed>(
|
||||
pub fn text_embeddings<D: Embed>(
|
||||
&self,
|
||||
model: &fastembed::EmbeddingModel,
|
||||
) -> EmbeddingsBuilder<EmbeddingModel, D> {
|
||||
EmbeddingsBuilder::new(self.embedding_model(model))
|
||||
model: &FastembedTextModel,
|
||||
) -> EmbeddingsBuilder<TextEmbeddingModel, D> {
|
||||
EmbeddingsBuilder::new(self.text_embedding_model(model))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EmbeddingModel {
|
||||
pub struct TextEmbeddingModel {
|
||||
embedder: Arc<TextEmbedding>,
|
||||
pub model: FastembedModel,
|
||||
pub model: FastembedTextModel,
|
||||
ndims: usize,
|
||||
}
|
||||
|
||||
impl EmbeddingModel {
|
||||
pub fn new(model: &fastembed::EmbeddingModel, ndims: usize) -> Self {
|
||||
impl TextEmbeddingModel {
|
||||
pub fn new(model: &fastembed::EmbeddingModel) -> Self {
|
||||
let embedder = Arc::new(
|
||||
TextEmbedding::try_new(
|
||||
InitOptions::new(model.to_owned()).with_show_download_progress(true),
|
||||
|
@ -83,6 +104,8 @@ impl EmbeddingModel {
|
|||
.unwrap(),
|
||||
);
|
||||
|
||||
let ndims = TextEmbedding::get_model_info(model).unwrap().dim;
|
||||
|
||||
Self {
|
||||
embedder,
|
||||
model: model.to_owned(),
|
||||
|
@ -93,7 +116,7 @@ impl EmbeddingModel {
|
|||
pub fn new_from_user_defined(
|
||||
user_defined_model: UserDefinedEmbeddingModel,
|
||||
ndims: usize,
|
||||
model_info: &ModelInfo<FastembedModel>,
|
||||
model_info: &ModelInfo<FastembedTextModel>,
|
||||
) -> Self {
|
||||
let fastembed_embedding_model = TextEmbedding::try_new_from_user_defined(
|
||||
user_defined_model,
|
||||
|
@ -111,7 +134,7 @@ impl EmbeddingModel {
|
|||
}
|
||||
}
|
||||
|
||||
impl embeddings::EmbeddingModel for EmbeddingModel {
|
||||
impl embeddings::EmbeddingModel for TextEmbeddingModel {
|
||||
const MAX_DOCUMENTS: usize = 1024;
|
||||
|
||||
fn ndims(&self) -> usize {
|
||||
|
@ -142,35 +165,76 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
|
|||
}
|
||||
}
|
||||
|
||||
/// As seen on the text embedding model cards file: <https://github.com/Anush008/fastembed-rs/blob/main/src/models/text_embedding.rs>
|
||||
pub fn fetch_model_ndims(model: &FastembedModel) -> usize {
|
||||
match model {
|
||||
FastembedModel::AllMiniLML6V2
|
||||
| FastembedModel::AllMiniLML6V2Q
|
||||
| FastembedModel::AllMiniLML12V2
|
||||
| FastembedModel::AllMiniLML12V2Q
|
||||
| FastembedModel::BGESmallENV15
|
||||
| FastembedModel::BGESmallENV15Q
|
||||
| FastembedModel::ParaphraseMLMiniLML12V2Q
|
||||
| FastembedModel::ParaphraseMLMiniLML12V2
|
||||
| FastembedModel::MultilingualE5Small => 384,
|
||||
FastembedModel::BGESmallZHV15 | FastembedModel::ClipVitB32 => 512,
|
||||
FastembedModel::BGEBaseENV15
|
||||
| FastembedModel::BGEBaseENV15Q
|
||||
| FastembedModel::NomicEmbedTextV1
|
||||
| FastembedModel::NomicEmbedTextV15
|
||||
| FastembedModel::NomicEmbedTextV15Q
|
||||
| FastembedModel::ParaphraseMLMpnetBaseV2
|
||||
| FastembedModel::MultilingualE5Base
|
||||
| FastembedModel::GTEBaseENV15
|
||||
| FastembedModel::GTEBaseENV15Q
|
||||
| FastembedModel::JinaEmbeddingsV2BaseCode => 768,
|
||||
FastembedModel::BGELargeENV15
|
||||
| FastembedModel::BGELargeENV15Q
|
||||
| FastembedModel::MultilingualE5Large
|
||||
| FastembedModel::MxbaiEmbedLargeV1
|
||||
| FastembedModel::MxbaiEmbedLargeV1Q
|
||||
| FastembedModel::GTELargeENV15
|
||||
| FastembedModel::GTELargeENV15Q => 1024,
|
||||
#[derive(Clone)]
|
||||
pub struct ImageEmbeddingModel {
|
||||
embedder: Arc<ImageEmbedding>,
|
||||
pub model: FastembedImageModel,
|
||||
ndims: usize,
|
||||
}
|
||||
|
||||
impl ImageEmbeddingModel {
|
||||
pub fn new(model: &FastembedImageModel) -> Self {
|
||||
let embedder = Arc::new(
|
||||
ImageEmbedding::try_new(
|
||||
ImageInitOptions::new(model.to_owned()).with_show_download_progress(true),
|
||||
)
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let ndims = ImageEmbedding::get_model_info(model).dim;
|
||||
|
||||
Self {
|
||||
embedder,
|
||||
model: model.to_owned(),
|
||||
ndims,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_from_user_defined(
|
||||
user_defined_model: UserDefinedImageEmbeddingModel,
|
||||
ndims: usize,
|
||||
model_info: &ModelInfo<FastembedImageModel>,
|
||||
) -> Self {
|
||||
let fastembed_embedding_model = ImageEmbedding::try_new_from_user_defined(
|
||||
user_defined_model,
|
||||
ImageInitOptionsUserDefined::default(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let embedder = Arc::new(fastembed_embedding_model);
|
||||
|
||||
Self {
|
||||
embedder,
|
||||
model: model_info.model.to_owned(),
|
||||
ndims,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl rig::embeddings::embedding::ImageEmbeddingModel for ImageEmbeddingModel {
|
||||
const MAX_DOCUMENTS: usize = 1024;
|
||||
|
||||
fn ndims(&self) -> usize {
|
||||
self.ndims
|
||||
}
|
||||
|
||||
async fn embed_images(
|
||||
&self,
|
||||
documents: impl IntoIterator<Item = Vec<u8>>,
|
||||
) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
|
||||
let images: Vec<Vec<u8>> = documents.into_iter().collect();
|
||||
let images: Vec<&[u8]> = images.iter().map(|x| x.as_slice()).collect();
|
||||
|
||||
let images_as_vec = self.embedder.embed_bytes(&images, None).unwrap();
|
||||
|
||||
let docs = images_as_vec
|
||||
.into_iter()
|
||||
.map(|embedding| embeddings::Embedding {
|
||||
document: String::new(),
|
||||
vec: embedding.into_iter().map(|f| f as f64).collect(),
|
||||
})
|
||||
.collect::<Vec<embeddings::Embedding>>();
|
||||
|
||||
Ok(docs)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue