diff --git a/rig-fastembed/src/lib.rs b/rig-fastembed/src/lib.rs index 9bf3a7a..3b066efe 100644 --- a/rig-fastembed/src/lib.rs +++ b/rig-fastembed/src/lib.rs @@ -38,9 +38,7 @@ impl Client { /// let embedding_model = fastembed_client.embedding_model(&FastembedModel::AllMiniLML6V2Q); /// ``` pub fn embedding_model(&self, model: &FastembedModel) -> EmbeddingModel { - let ndims = fetch_model_ndims(model); - - EmbeddingModel::new(model, ndims) + EmbeddingModel::new(model) } /// Create an embedding builder with the given embedding model. @@ -75,7 +73,7 @@ pub struct EmbeddingModel { } impl EmbeddingModel { - pub fn new(model: &fastembed::EmbeddingModel, ndims: usize) -> Self { + pub fn new(model: &fastembed::EmbeddingModel) -> Self { let embedder = Arc::new( TextEmbedding::try_new( InitOptions::new(model.to_owned()).with_show_download_progress(true), @@ -83,6 +81,8 @@ impl EmbeddingModel { .unwrap(), ); + let ndims = TextEmbedding::get_model_info(&model).unwrap().dim; + Self { embedder, model: model.to_owned(), @@ -141,37 +141,3 @@ impl embeddings::EmbeddingModel for EmbeddingModel { Ok(docs) } } - -/// As seen on the text embedding model cards file: -pub fn fetch_model_ndims(model: &FastembedModel) -> usize { - match model { - FastembedModel::AllMiniLML6V2 - | FastembedModel::AllMiniLML6V2Q - | FastembedModel::AllMiniLML12V2 - | FastembedModel::AllMiniLML12V2Q - | FastembedModel::BGESmallENV15 - | FastembedModel::BGESmallENV15Q - | FastembedModel::ParaphraseMLMiniLML12V2Q - | FastembedModel::ParaphraseMLMiniLML12V2 - | FastembedModel::MultilingualE5Small => 384, - FastembedModel::BGESmallZHV15 | FastembedModel::ClipVitB32 => 512, - FastembedModel::BGEBaseENV15 - | FastembedModel::BGEBaseENV15Q - | FastembedModel::NomicEmbedTextV1 - | FastembedModel::NomicEmbedTextV15 - | FastembedModel::NomicEmbedTextV15Q - | FastembedModel::ParaphraseMLMpnetBaseV2 - | FastembedModel::MultilingualE5Base - | FastembedModel::GTEBaseENV15 - | FastembedModel::GTEBaseENV15Q - | FastembedModel::JinaEmbeddingsV2BaseCode => 768, - FastembedModel::BGELargeENV15 - | FastembedModel::BGELargeENV15Q - | FastembedModel::MultilingualE5Large - | FastembedModel::MxbaiEmbedLargeV1 - | FastembedModel::MxbaiEmbedLargeV1Q - | FastembedModel::GTELargeENV15 - | FastembedModel::ModernBertEmbedLarge - | FastembedModel::GTELargeENV15Q => 1024, - } -}