feat: remove method for manually fetching model info

This commit is contained in:
Joshua Mo 2025-04-11 11:24:14 +01:00
parent 8fdd2e9f92
commit b7ec64e7d6
1 changed file with 4 additions and 38 deletions

View File

@ -38,9 +38,7 @@ impl Client {
/// let embedding_model = fastembed_client.embedding_model(&FastembedModel::AllMiniLML6V2Q); /// let embedding_model = fastembed_client.embedding_model(&FastembedModel::AllMiniLML6V2Q);
/// ``` /// ```
pub fn embedding_model(&self, model: &FastembedModel) -> EmbeddingModel { pub fn embedding_model(&self, model: &FastembedModel) -> EmbeddingModel {
let ndims = fetch_model_ndims(model); EmbeddingModel::new(model)
EmbeddingModel::new(model, ndims)
} }
/// Create an embedding builder with the given embedding model. /// Create an embedding builder with the given embedding model.
@ -75,7 +73,7 @@ pub struct EmbeddingModel {
} }
impl EmbeddingModel { impl EmbeddingModel {
pub fn new(model: &fastembed::EmbeddingModel, ndims: usize) -> Self { pub fn new(model: &fastembed::EmbeddingModel) -> Self {
let embedder = Arc::new( let embedder = Arc::new(
TextEmbedding::try_new( TextEmbedding::try_new(
InitOptions::new(model.to_owned()).with_show_download_progress(true), InitOptions::new(model.to_owned()).with_show_download_progress(true),
@ -83,6 +81,8 @@ impl EmbeddingModel {
.unwrap(), .unwrap(),
); );
let ndims = TextEmbedding::get_model_info(&model).unwrap().dim;
Self { Self {
embedder, embedder,
model: model.to_owned(), model: model.to_owned(),
@ -141,37 +141,3 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
Ok(docs) Ok(docs)
} }
} }
/// Look up the embedding dimension count associated with a fastembed model.
///
/// The numbers are a hand-maintained table taken from the text embedding model
/// cards file:
/// <https://github.com/Anush008/fastembed-rs/blob/main/src/models/text_embedding.rs>
///
/// The `match` deliberately has no wildcard arm, so every `FastembedModel`
/// variant must be listed explicitly for this function to compile.
pub fn fetch_model_ndims(model: &FastembedModel) -> usize {
    let dimensions: usize = match model {
        // 1024-dimensional models.
        FastembedModel::BGELargeENV15
        | FastembedModel::BGELargeENV15Q
        | FastembedModel::GTELargeENV15
        | FastembedModel::GTELargeENV15Q
        | FastembedModel::ModernBertEmbedLarge
        | FastembedModel::MultilingualE5Large
        | FastembedModel::MxbaiEmbedLargeV1
        | FastembedModel::MxbaiEmbedLargeV1Q => 1024,

        // 768-dimensional models.
        FastembedModel::BGEBaseENV15
        | FastembedModel::BGEBaseENV15Q
        | FastembedModel::GTEBaseENV15
        | FastembedModel::GTEBaseENV15Q
        | FastembedModel::JinaEmbeddingsV2BaseCode
        | FastembedModel::MultilingualE5Base
        | FastembedModel::NomicEmbedTextV1
        | FastembedModel::NomicEmbedTextV15
        | FastembedModel::NomicEmbedTextV15Q
        | FastembedModel::ParaphraseMLMpnetBaseV2 => 768,

        // 512-dimensional models.
        FastembedModel::BGESmallZHV15 | FastembedModel::ClipVitB32 => 512,

        // 384-dimensional models.
        FastembedModel::AllMiniLML6V2
        | FastembedModel::AllMiniLML6V2Q
        | FastembedModel::AllMiniLML12V2
        | FastembedModel::AllMiniLML12V2Q
        | FastembedModel::BGESmallENV15
        | FastembedModel::BGESmallENV15Q
        | FastembedModel::MultilingualE5Small
        | FastembedModel::ParaphraseMLMiniLML12V2
        | FastembedModel::ParaphraseMLMiniLML12V2Q => 384,
    };

    dimensions
}