fix: gemini embeddings does not work for multiple documents

Joshua Mo 2025-04-09 18:47:20 +01:00
parent 92c91d23c3
commit 10757903c0
2 changed files with 34 additions and 17 deletions

Changed file 1 of 2:

@@ -107,8 +107,11 @@ impl<M: EmbeddingModel, T: Embed + Send> EmbeddingsBuilder<M, T> {
             // Generate the embeddings for each batch.
             .map(|text| async {
                 let (ids, docs): (Vec<_>, Vec<_>) = text.into_iter().unzip();
+                println!("{ids:?}");
+                println!("{docs:?}");
                 let embeddings = self.model.embed_texts(docs).await?;
+                println!("Embeddings vec len: {}", embeddings.len());
                 Ok::<_, EmbeddingError>(ids.into_iter().zip(embeddings).collect::<Vec<_>>())
             })
             // Parallelize the embeddings generation over 10 concurrent requests
@@ -128,6 +131,8 @@ impl<M: EmbeddingModel, T: Embed + Send> EmbeddingsBuilder<M, T> {
             )
             .await?;
 
+        // println!("{embeddings:?}");
+
         // Merge the embeddings with their respective documents
         Ok(docs
             .into_iter()
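
The builder pairs each batch's ids with whatever `embed_texts` returns via `zip`, so it silently assumes the provider hands back exactly one embedding per input document, in order. A minimal sketch of why a provider that collapses a batch into a single embedding drops documents (the `pair` helper and the literal values are illustrative only, not rig APIs):

// Illustrative only: shows the zip-truncation behaviour the builder relies on.
fn pair<T>(ids: Vec<T>, embeddings: Vec<Vec<f64>>) -> Vec<(T, Vec<f64>)> {
    // zip stops at the shorter iterator, so a short embeddings vec drops ids.
    ids.into_iter().zip(embeddings).collect()
}

fn main() {
    let ids = vec!["doc-0", "doc-1", "doc-2"];
    let embeddings = vec![vec![0.1, 0.2]]; // a buggy provider: one vector for three docs
    assert_eq!(pair(ids, embeddings).len(), 1); // doc-1 and doc-2 are silently lost
}

This is the failure mode the Gemini change below addresses: the old code sent all documents as parts of a single embedContent request and chunked one returned vector, which cannot yield one embedding per document.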

Changed file 2 of 2:

@@ -41,26 +41,37 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
         }
     }
 
+    /// <https://ai.google.dev/api/embeddings#batch_embed_contents-SHELL>
     #[cfg_attr(feature = "worker", worker::send)]
     async fn embed_texts(
         &self,
         documents: impl IntoIterator<Item = String> + Send,
     ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
-        let documents: Vec<_> = documents.into_iter().collect();
+        let documents: Vec<String> = documents.into_iter().collect();
 
-        let mut request_body = json!({
-            "model": format!("models/{}", self.model),
-            "content": {
-                "parts": documents.iter().map(|doc| json!({ "text": doc })).collect::<Vec<_>>(),
-            },
-        });
-        if let Some(ndims) = self.ndims {
-            request_body["output_dimensionality"] = json!(ndims);
-        }
+        // Google batch embed requests. See docstrings for API ref link.
+        let requests: Vec<_> = documents
+            .iter()
+            .map(|doc| {
+                json!({
+                    "model": format!("models/{}", self.model),
+                    "content": json!({
+                        "parts": [json!({
+                            "text": doc.to_string()
+                        })]
+                    }),
+                    "output_dimensionality": self.ndims,
+                })
+            })
+            .collect();
+
+        let request_body = json!({ "requests": requests });
+
+        println!("{}", serde_json::to_string_pretty(&request_body).unwrap());
 
         let response = self
             .client
-            .post(&format!("/v1beta/models/{}:embedContent", self.model))
+            .post(&format!("/v1beta/models/{}:batchEmbedContents", self.model))
            .json(&request_body)
             .send()
             .await?
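
For two documents, the rewritten provider serializes one request object per document into the batchEmbedContents body. A rough sketch of the resulting JSON (the model name "text-embedding-004" and `ndims = Some(768)` are assumed example values, not taken from this diff):

use serde_json::json;

fn main() {
    // Shape of the batchEmbedContents body assembled above, for two documents.
    let body = json!({
        "requests": [
            {
                "model": "models/text-embedding-004",
                "content": { "parts": [{ "text": "first document" }] },
                "output_dimensionality": 768
            },
            {
                "model": "models/text-embedding-004",
                "content": { "parts": [{ "text": "second document" }] },
                "output_dimensionality": 768
            }
        ]
    });
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}

One side effect of writing `"output_dimensionality": self.ndims` with an `Option` is that a `None` value serializes as `null` rather than omitting the field; whether the endpoint tolerates that is outside the scope of this diff.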
@@ -70,15 +81,16 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
         match response {
             ApiResponse::Ok(response) => {
-                let chunk_size = self.ndims.unwrap_or_else(|| self.ndims());
-
-                Ok(documents
+                let docs = documents
                     .into_iter()
-                    .zip(response.embedding.values.chunks(chunk_size))
+                    .zip(response.embeddings)
                     .map(|(document, embedding)| embeddings::Embedding {
                         document,
-                        vec: embedding.to_vec(),
+                        vec: embedding.values,
                     })
-                    .collect())
+                    .collect();
+
+                Ok(docs)
             }
             ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
         }
@@ -196,7 +208,7 @@ mod gemini_api_types {
     #[derive(Debug, Deserialize)]
     pub struct EmbeddingResponse {
-        pub embedding: EmbeddingValues,
+        pub embeddings: Vec<EmbeddingValues>,
     }
 
     #[derive(Debug, Deserialize)]
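
With `EmbeddingResponse` reshaped to hold `embeddings: Vec<EmbeddingValues>`, a batch reply deserializes into one entry per request, which is what makes the 1:1 `zip` in `embed_texts` sound. A small deserialization sketch (the struct shapes mirror this diff; the `f64` element type of `values` is an assumption):

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
    embeddings: Vec<EmbeddingValues>,
}

#[derive(Debug, Deserialize)]
struct EmbeddingValues {
    values: Vec<f64>, // element type assumed for illustration
}

fn main() -> Result<(), serde_json::Error> {
    let raw = r#"{"embeddings": [{"values": [0.1, 0.2]}, {"values": [0.3, 0.4]}]}"#;
    let parsed: EmbeddingResponse = serde_json::from_str(raw)?;
    assert_eq!(parsed.embeddings.len(), 2); // one embedding per request
    Ok(())
}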