mirror of https://github.com/0xplaygrounds/rig
fix: gemini embeddings does not work for multiple documents
This commit is contained in:
parent
92c91d23c3
commit
10757903c0
|
@ -107,8 +107,11 @@ impl<M: EmbeddingModel, T: Embed + Send> EmbeddingsBuilder<M, T> {
|
||||||
// Generate the embeddings for each batch.
|
// Generate the embeddings for each batch.
|
||||||
.map(|text| async {
|
.map(|text| async {
|
||||||
let (ids, docs): (Vec<_>, Vec<_>) = text.into_iter().unzip();
|
let (ids, docs): (Vec<_>, Vec<_>) = text.into_iter().unzip();
|
||||||
|
println!("{ids:?}");
|
||||||
|
println!("{docs:?}");
|
||||||
|
|
||||||
let embeddings = self.model.embed_texts(docs).await?;
|
let embeddings = self.model.embed_texts(docs).await?;
|
||||||
|
println!("Embeddings vec len: {}", embeddings.len());
|
||||||
Ok::<_, EmbeddingError>(ids.into_iter().zip(embeddings).collect::<Vec<_>>())
|
Ok::<_, EmbeddingError>(ids.into_iter().zip(embeddings).collect::<Vec<_>>())
|
||||||
})
|
})
|
||||||
// Parallelize the embeddings generation over 10 concurrent requests
|
// Parallelize the embeddings generation over 10 concurrent requests
|
||||||
|
@ -128,6 +131,8 @@ impl<M: EmbeddingModel, T: Embed + Send> EmbeddingsBuilder<M, T> {
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
// println!("{embeddings:?}");
|
||||||
|
|
||||||
// Merge the embeddings with their respective documents
|
// Merge the embeddings with their respective documents
|
||||||
Ok(docs
|
Ok(docs
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
|
|
@ -41,26 +41,37 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <https://ai.google.dev/api/embeddings#batch_embed_contents-SHELL>
|
||||||
#[cfg_attr(feature = "worker", worker::send)]
|
#[cfg_attr(feature = "worker", worker::send)]
|
||||||
async fn embed_texts(
|
async fn embed_texts(
|
||||||
&self,
|
&self,
|
||||||
documents: impl IntoIterator<Item = String> + Send,
|
documents: impl IntoIterator<Item = String> + Send,
|
||||||
) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
|
) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
|
||||||
let documents: Vec<_> = documents.into_iter().collect();
|
let documents: Vec<String> = documents.into_iter().collect();
|
||||||
let mut request_body = json!({
|
|
||||||
"model": format!("models/{}", self.model),
|
|
||||||
"content": {
|
|
||||||
"parts": documents.iter().map(|doc| json!({ "text": doc })).collect::<Vec<_>>(),
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
if let Some(ndims) = self.ndims {
|
// Google batch embed requests. See docstrings for API ref link.
|
||||||
request_body["output_dimensionality"] = json!(ndims);
|
let requests: Vec<_> = documents
|
||||||
}
|
.iter()
|
||||||
|
.map(|doc| {
|
||||||
|
json!({
|
||||||
|
"model": format!("models/{}", self.model),
|
||||||
|
"content": json!({
|
||||||
|
"parts": [json!({
|
||||||
|
"text": doc.to_string()
|
||||||
|
})]
|
||||||
|
}),
|
||||||
|
"output_dimensionality": self.ndims,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let request_body = json!({ "requests": requests });
|
||||||
|
|
||||||
|
println!("{}", serde_json::to_string_pretty(&request_body).unwrap());
|
||||||
|
|
||||||
let response = self
|
let response = self
|
||||||
.client
|
.client
|
||||||
.post(&format!("/v1beta/models/{}:embedContent", self.model))
|
.post(&format!("/v1beta/models/{}:batchEmbedContents", self.model))
|
||||||
.json(&request_body)
|
.json(&request_body)
|
||||||
.send()
|
.send()
|
||||||
.await?
|
.await?
|
||||||
|
@ -70,15 +81,16 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
|
||||||
|
|
||||||
match response {
|
match response {
|
||||||
ApiResponse::Ok(response) => {
|
ApiResponse::Ok(response) => {
|
||||||
let chunk_size = self.ndims.unwrap_or_else(|| self.ndims());
|
let docs = documents
|
||||||
Ok(documents
|
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.zip(response.embedding.values.chunks(chunk_size))
|
.zip(response.embeddings)
|
||||||
.map(|(document, embedding)| embeddings::Embedding {
|
.map(|(document, embedding)| embeddings::Embedding {
|
||||||
document,
|
document,
|
||||||
vec: embedding.to_vec(),
|
vec: embedding.values,
|
||||||
})
|
})
|
||||||
.collect())
|
.collect();
|
||||||
|
|
||||||
|
Ok(docs)
|
||||||
}
|
}
|
||||||
ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
|
ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
|
||||||
}
|
}
|
||||||
|
@ -196,7 +208,7 @@ mod gemini_api_types {
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct EmbeddingResponse {
|
pub struct EmbeddingResponse {
|
||||||
pub embedding: EmbeddingValues,
|
pub embeddings: Vec<EmbeddingValues>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
|
|
Loading…
Reference in New Issue