fix: Gemini embeddings do not work for multiple documents

This commit is contained in:
Joshua Mo 2025-04-09 18:47:20 +01:00
parent 92c91d23c3
commit 10757903c0
2 changed files with 34 additions and 17 deletions

View File

@ -107,8 +107,11 @@ impl<M: EmbeddingModel, T: Embed + Send> EmbeddingsBuilder<M, T> {
// Generate the embeddings for each batch.
.map(|text| async {
let (ids, docs): (Vec<_>, Vec<_>) = text.into_iter().unzip();
println!("{ids:?}");
println!("{docs:?}");
let embeddings = self.model.embed_texts(docs).await?;
println!("Embeddings vec len: {}", embeddings.len());
Ok::<_, EmbeddingError>(ids.into_iter().zip(embeddings).collect::<Vec<_>>())
})
// Parallelize the embeddings generation over 10 concurrent requests
@ -128,6 +131,8 @@ impl<M: EmbeddingModel, T: Embed + Send> EmbeddingsBuilder<M, T> {
)
.await?;
// println!("{embeddings:?}");
// Merge the embeddings with their respective documents
Ok(docs
.into_iter()

View File

@ -41,26 +41,37 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
}
}
/// <https://ai.google.dev/api/embeddings#batch_embed_contents-SHELL>
#[cfg_attr(feature = "worker", worker::send)]
async fn embed_texts(
&self,
documents: impl IntoIterator<Item = String> + Send,
) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
let documents: Vec<_> = documents.into_iter().collect();
let mut request_body = json!({
"model": format!("models/{}", self.model),
"content": {
"parts": documents.iter().map(|doc| json!({ "text": doc })).collect::<Vec<_>>(),
},
});
let documents: Vec<String> = documents.into_iter().collect();
if let Some(ndims) = self.ndims {
request_body["output_dimensionality"] = json!(ndims);
}
// Google batch embed requests. See docstrings for API ref link.
let requests: Vec<_> = documents
.iter()
.map(|doc| {
json!({
"model": format!("models/{}", self.model),
"content": json!({
"parts": [json!({
"text": doc.to_string()
})]
}),
"output_dimensionality": self.ndims,
})
})
.collect();
let request_body = json!({ "requests": requests });
println!("{}", serde_json::to_string_pretty(&request_body).unwrap());
let response = self
.client
.post(&format!("/v1beta/models/{}:embedContent", self.model))
.post(&format!("/v1beta/models/{}:batchEmbedContents", self.model))
.json(&request_body)
.send()
.await?
@ -70,15 +81,16 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
match response {
ApiResponse::Ok(response) => {
let chunk_size = self.ndims.unwrap_or_else(|| self.ndims());
Ok(documents
let docs = documents
.into_iter()
.zip(response.embedding.values.chunks(chunk_size))
.zip(response.embeddings)
.map(|(document, embedding)| embeddings::Embedding {
document,
vec: embedding.to_vec(),
vec: embedding.values,
})
.collect())
.collect();
Ok(docs)
}
ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
}
@ -196,7 +208,7 @@ mod gemini_api_types {
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
pub embedding: EmbeddingValues,
pub embeddings: Vec<EmbeddingValues>,
}
#[derive(Debug, Deserialize)]