mirror of https://github.com/0xplaygrounds/rig
fix: Gemini embeddings do not work for multiple documents
This commit is contained in:
parent
92c91d23c3
commit
10757903c0
|
@ -107,8 +107,11 @@ impl<M: EmbeddingModel, T: Embed + Send> EmbeddingsBuilder<M, T> {
|
|||
// Generate the embeddings for each batch.
|
||||
.map(|text| async {
|
||||
let (ids, docs): (Vec<_>, Vec<_>) = text.into_iter().unzip();
|
||||
println!("{ids:?}");
|
||||
println!("{docs:?}");
|
||||
|
||||
let embeddings = self.model.embed_texts(docs).await?;
|
||||
println!("Embeddings vec len: {}", embeddings.len());
|
||||
Ok::<_, EmbeddingError>(ids.into_iter().zip(embeddings).collect::<Vec<_>>())
|
||||
})
|
||||
// Parallelize the embeddings generation over 10 concurrent requests
|
||||
|
@ -128,6 +131,8 @@ impl<M: EmbeddingModel, T: Embed + Send> EmbeddingsBuilder<M, T> {
|
|||
)
|
||||
.await?;
|
||||
|
||||
// println!("{embeddings:?}");
|
||||
|
||||
// Merge the embeddings with their respective documents
|
||||
Ok(docs
|
||||
.into_iter()
|
||||
|
|
|
@ -41,26 +41,37 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
|
|||
}
|
||||
}
|
||||
|
||||
/// <https://ai.google.dev/api/embeddings#batch_embed_contents-SHELL>
|
||||
#[cfg_attr(feature = "worker", worker::send)]
|
||||
async fn embed_texts(
|
||||
&self,
|
||||
documents: impl IntoIterator<Item = String> + Send,
|
||||
) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
|
||||
let documents: Vec<_> = documents.into_iter().collect();
|
||||
let mut request_body = json!({
|
||||
"model": format!("models/{}", self.model),
|
||||
"content": {
|
||||
"parts": documents.iter().map(|doc| json!({ "text": doc })).collect::<Vec<_>>(),
|
||||
},
|
||||
});
|
||||
let documents: Vec<String> = documents.into_iter().collect();
|
||||
|
||||
if let Some(ndims) = self.ndims {
|
||||
request_body["output_dimensionality"] = json!(ndims);
|
||||
}
|
||||
// Google batch embed requests. See docstrings for API ref link.
|
||||
let requests: Vec<_> = documents
|
||||
.iter()
|
||||
.map(|doc| {
|
||||
json!({
|
||||
"model": format!("models/{}", self.model),
|
||||
"content": json!({
|
||||
"parts": [json!({
|
||||
"text": doc.to_string()
|
||||
})]
|
||||
}),
|
||||
"output_dimensionality": self.ndims,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let request_body = json!({ "requests": requests });
|
||||
|
||||
println!("{}", serde_json::to_string_pretty(&request_body).unwrap());
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&format!("/v1beta/models/{}:embedContent", self.model))
|
||||
.post(&format!("/v1beta/models/{}:batchEmbedContents", self.model))
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await?
|
||||
|
@ -70,15 +81,16 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
|
|||
|
||||
match response {
|
||||
ApiResponse::Ok(response) => {
|
||||
let chunk_size = self.ndims.unwrap_or_else(|| self.ndims());
|
||||
Ok(documents
|
||||
let docs = documents
|
||||
.into_iter()
|
||||
.zip(response.embedding.values.chunks(chunk_size))
|
||||
.zip(response.embeddings)
|
||||
.map(|(document, embedding)| embeddings::Embedding {
|
||||
document,
|
||||
vec: embedding.to_vec(),
|
||||
vec: embedding.values,
|
||||
})
|
||||
.collect())
|
||||
.collect();
|
||||
|
||||
Ok(docs)
|
||||
}
|
||||
ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
|
||||
}
|
||||
|
@ -196,7 +208,7 @@ mod gemini_api_types {
|
|||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct EmbeddingResponse {
|
||||
pub embedding: EmbeddingValues,
|
||||
pub embeddings: Vec<EmbeddingValues>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
|
Loading…
Reference in New Issue