diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 389e17b..39dc31e 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -107,8 +107,11 @@ impl EmbeddingsBuilder { // Generate the embeddings for each batch. .map(|text| async { let (ids, docs): (Vec<_>, Vec<_>) = text.into_iter().unzip(); + println!("{ids:?}"); + println!("{docs:?}"); let embeddings = self.model.embed_texts(docs).await?; + println!("Embeddings vec len: {}", embeddings.len()); Ok::<_, EmbeddingError>(ids.into_iter().zip(embeddings).collect::>()) }) // Parallelize the embeddings generation over 10 concurrent requests @@ -128,6 +131,8 @@ impl EmbeddingsBuilder { ) .await?; + // println!("{embeddings:?}"); + // Merge the embeddings with their respective documents Ok(docs .into_iter() diff --git a/rig-core/src/providers/gemini/embedding.rs b/rig-core/src/providers/gemini/embedding.rs index 87cdd32..b14d56c 100644 --- a/rig-core/src/providers/gemini/embedding.rs +++ b/rig-core/src/providers/gemini/embedding.rs @@ -41,26 +41,37 @@ impl embeddings::EmbeddingModel for EmbeddingModel { } } + /// #[cfg_attr(feature = "worker", worker::send)] async fn embed_texts( &self, documents: impl IntoIterator + Send, ) -> Result, EmbeddingError> { - let documents: Vec<_> = documents.into_iter().collect(); - let mut request_body = json!({ - "model": format!("models/{}", self.model), - "content": { - "parts": documents.iter().map(|doc| json!({ "text": doc })).collect::>(), - }, - }); + let documents: Vec = documents.into_iter().collect(); - if let Some(ndims) = self.ndims { - request_body["output_dimensionality"] = json!(ndims); - } + // Google batch embed requests. See docstrings for API ref link. + let requests: Vec<_> = documents + .iter() + .map(|doc| { + json!({ + "model": format!("models/{}", self.model), + "content": json!({ + "parts": [json!({ + "text": doc.to_string() + })] + }), + "output_dimensionality": self.ndims, + }) + }) + .collect(); + + let request_body = json!({ "requests": requests }); + + println!("{}", serde_json::to_string_pretty(&request_body).unwrap()); let response = self .client - .post(&format!("/v1beta/models/{}:embedContent", self.model)) + .post(&format!("/v1beta/models/{}:batchEmbedContents", self.model)) .json(&request_body) .send() .await? @@ -70,15 +81,16 @@ impl embeddings::EmbeddingModel for EmbeddingModel { match response { ApiResponse::Ok(response) => { - let chunk_size = self.ndims.unwrap_or_else(|| self.ndims()); - Ok(documents + let docs = documents .into_iter() - .zip(response.embedding.values.chunks(chunk_size)) + .zip(response.embeddings) .map(|(document, embedding)| embeddings::Embedding { document, - vec: embedding.to_vec(), + vec: embedding.values, }) - .collect()) + .collect(); + + Ok(docs) } ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)), } @@ -196,7 +208,7 @@ mod gemini_api_types { #[derive(Debug, Deserialize)] pub struct EmbeddingResponse { - pub embedding: EmbeddingValues, + pub embeddings: Vec, } #[derive(Debug, Deserialize)]