fix: gemini embeddings does not work for multiple documents (#386)

* fix: gemini embeddings does not work for multiple documents

* refactor: remove unnecessary printlns
This commit is contained in:
Joshua Mo 2025-04-09 23:25:13 +01:00 committed by GitHub
parent bc11decc0b
commit d10d1cc73b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 29 additions and 17 deletions

View File

@ -41,26 +41,37 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
}
}
/// <https://ai.google.dev/api/embeddings#batch_embed_contents-SHELL>
#[cfg_attr(feature = "worker", worker::send)]
async fn embed_texts(
&self,
documents: impl IntoIterator<Item = String> + Send,
) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
let documents: Vec<_> = documents.into_iter().collect();
let mut request_body = json!({
"model": format!("models/{}", self.model),
"content": {
"parts": documents.iter().map(|doc| json!({ "text": doc })).collect::<Vec<_>>(),
},
});
let documents: Vec<String> = documents.into_iter().collect();
if let Some(ndims) = self.ndims {
request_body["output_dimensionality"] = json!(ndims);
}
// Google batch embed requests. See docstrings for API ref link.
let requests: Vec<_> = documents
.iter()
.map(|doc| {
json!({
"model": format!("models/{}", self.model),
"content": json!({
"parts": [json!({
"text": doc.to_string()
})]
}),
"output_dimensionality": self.ndims,
})
})
.collect();
let request_body = json!({ "requests": requests });
println!("{}", serde_json::to_string_pretty(&request_body).unwrap());
let response = self
.client
.post(&format!("/v1beta/models/{}:embedContent", self.model))
.post(&format!("/v1beta/models/{}:batchEmbedContents", self.model))
.json(&request_body)
.send()
.await?
@ -70,15 +81,16 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
match response {
ApiResponse::Ok(response) => {
let chunk_size = self.ndims.unwrap_or_else(|| self.ndims());
Ok(documents
let docs = documents
.into_iter()
.zip(response.embedding.values.chunks(chunk_size))
.zip(response.embeddings)
.map(|(document, embedding)| embeddings::Embedding {
document,
vec: embedding.to_vec(),
vec: embedding.values,
})
.collect())
.collect();
Ok(docs)
}
ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
}
@ -196,7 +208,7 @@ mod gemini_api_types {
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
pub embedding: EmbeddingValues,
pub embeddings: Vec<EmbeddingValues>,
}
#[derive(Debug, Deserialize)]