feat: remove method for manually fetching model info

This commit is contained in:
Joshua Mo 2025-04-11 11:24:14 +01:00
parent 8fdd2e9f92
commit b7ec64e7d6
1 changed files with 4 additions and 38 deletions

View File

@ -38,9 +38,7 @@ impl Client {
/// let embedding_model = fastembed_client.embedding_model(&FastembedModel::AllMiniLML6V2Q);
/// ```
pub fn embedding_model(&self, model: &FastembedModel) -> EmbeddingModel {
let ndims = fetch_model_ndims(model);
EmbeddingModel::new(model, ndims)
EmbeddingModel::new(model)
}
/// Create an embedding builder with the given embedding model.
@ -75,7 +73,7 @@ pub struct EmbeddingModel {
}
impl EmbeddingModel {
pub fn new(model: &fastembed::EmbeddingModel, ndims: usize) -> Self {
pub fn new(model: &fastembed::EmbeddingModel) -> Self {
let embedder = Arc::new(
TextEmbedding::try_new(
InitOptions::new(model.to_owned()).with_show_download_progress(true),
@ -83,6 +81,8 @@ impl EmbeddingModel {
.unwrap(),
);
let ndims = TextEmbedding::get_model_info(&model).unwrap().dim;
Self {
embedder,
model: model.to_owned(),
@ -141,37 +141,3 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
Ok(docs)
}
}
/// As seen on the text embedding model cards file: <https://github.com/Anush008/fastembed-rs/blob/main/src/models/text_embedding.rs>
pub fn fetch_model_ndims(model: &FastembedModel) -> usize {
match model {
FastembedModel::AllMiniLML6V2
| FastembedModel::AllMiniLML6V2Q
| FastembedModel::AllMiniLML12V2
| FastembedModel::AllMiniLML12V2Q
| FastembedModel::BGESmallENV15
| FastembedModel::BGESmallENV15Q
| FastembedModel::ParaphraseMLMiniLML12V2Q
| FastembedModel::ParaphraseMLMiniLML12V2
| FastembedModel::MultilingualE5Small => 384,
FastembedModel::BGESmallZHV15 | FastembedModel::ClipVitB32 => 512,
FastembedModel::BGEBaseENV15
| FastembedModel::BGEBaseENV15Q
| FastembedModel::NomicEmbedTextV1
| FastembedModel::NomicEmbedTextV15
| FastembedModel::NomicEmbedTextV15Q
| FastembedModel::ParaphraseMLMpnetBaseV2
| FastembedModel::MultilingualE5Base
| FastembedModel::GTEBaseENV15
| FastembedModel::GTEBaseENV15Q
| FastembedModel::JinaEmbeddingsV2BaseCode => 768,
FastembedModel::BGELargeENV15
| FastembedModel::BGELargeENV15Q
| FastembedModel::MultilingualE5Large
| FastembedModel::MxbaiEmbedLargeV1
| FastembedModel::MxbaiEmbedLargeV1Q
| FastembedModel::GTELargeENV15
| FastembedModel::ModernBertEmbedLarge
| FastembedModel::GTELargeENV15Q => 1024,
}
}