diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml
index 3dac86b..3c5c7f3 100644
--- a/rig-core/Cargo.toml
+++ b/rig-core/Cargo.toml
@@ -127,3 +127,8 @@ required-features = ["audio"]
 [[example]]
 name = "hyperbolic_audio_generation"
 required-features = ["audio"]
+
+
+[[example]]
+name = "aliyun_embeddings"
+required-features = ["derive"]
diff --git a/rig-core/examples/aliyun_embeddings.rs b/rig-core/examples/aliyun_embeddings.rs
new file mode 100644
index 0000000..eb771f0
--- /dev/null
+++ b/rig-core/examples/aliyun_embeddings.rs
@@ -0,0 +1,30 @@
+use rig::providers::aliyun;
+use rig::Embed;
+
+#[derive(Embed, Debug)]
+struct Greetings {
+    #[embed]
+    message: String,
+}
+
+#[tokio::main]
+async fn main() -> Result<(), anyhow::Error> {
+    // Initialize the Aliyun client
+    let client = aliyun::Client::from_env();
+
+    let embeddings = client
+        .embeddings(aliyun::embedding::EMBEDDING_V1)
+        .document(Greetings {
+            message: "Hello, world!".to_string(),
+        })?
+        .document(Greetings {
+            message: "Goodbye, world!".to_string(),
+        })?
+        .build()
+        .await
+        .expect("Failed to embed documents");
+
+    println!("{:?}", embeddings);
+
+    Ok(())
+}
diff --git a/rig-core/src/providers/aliyun/client.rs b/rig-core/src/providers/aliyun/client.rs
new file mode 100644
index 0000000..3e095b3
--- /dev/null
+++ b/rig-core/src/providers/aliyun/client.rs
@@ -0,0 +1,168 @@
+use crate::{
+    embeddings::{self},
+    Embed,
+};
+use serde::Deserialize;
+
+use super::embedding::EmbeddingModel;
+
+// ================================================================
+// Aliyun (DashScope) Client
+// ================================================================
+const ALIYUN_API_BASE_URL: &str = "https://dashscope.aliyuncs.com";
+
+#[derive(Clone)]
+pub struct Client {
+    base_url: String,
+    api_key: String,
+    http_client: reqwest::Client,
+}
+
+impl Client {
+    /// Create a new Aliyun client with the given API key.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::aliyun::Client;
+    ///
+    /// // Initialize the Aliyun client
+    /// let aliyun = Client::new("your-dashscope-api-key");
+    /// ```
+    pub fn new(api_key: &str) -> Self {
+        Self::from_url(api_key, ALIYUN_API_BASE_URL)
+    }
+
+    /// Create a new Aliyun client with the given API key and base URL.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::aliyun::Client;
+    ///
+    /// // Initialize the Aliyun client with a custom base URL
+    /// let aliyun = Client::from_url("your-dashscope-api-key", "https://custom-dashscope-url.com");
+    /// ```
+    pub fn from_url(api_key: &str, base_url: &str) -> Self {
+        Self {
+            base_url: base_url.to_string(),
+            api_key: api_key.to_string(),
+            http_client: reqwest::Client::builder()
+                .default_headers({
+                    let mut headers = reqwest::header::HeaderMap::new();
+                    headers.insert(
+                        reqwest::header::CONTENT_TYPE,
+                        "application/json".parse().unwrap(),
+                    );
+                    headers
+                })
+                .build()
+                .expect("Aliyun reqwest client should build"),
+        }
+    }
+
+    /// Create a new Aliyun client from the `DASHSCOPE_API_KEY` environment variable.
+    /// Panics if the environment variable is not set.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::aliyun::Client;
+    ///
+    /// // Initialize the Aliyun client from environment variable
+    /// let aliyun = Client::from_env();
+    /// ```
+    /// # Panics
+    /// This function will panic if the `DASHSCOPE_API_KEY` environment variable is not set.
+    pub fn from_env() -> Self {
+        let api_key = std::env::var("DASHSCOPE_API_KEY").expect("DASHSCOPE_API_KEY not set");
+        Self::new(&api_key)
+    }
+
+    /// Create a POST request to the specified API endpoint path.
+    /// The Authorization header with the API key will be automatically added.
+    ///
+    /// # Arguments
+    /// * `path` - The API endpoint path to append to the base URL
+    ///
+    /// # Returns
+    /// A reqwest::RequestBuilder instance that can be further customized before sending
+    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
+        // Join base URL and path; a plain `replace("//", "/")` would also
+        // mangle the "https://" scheme separator, so trim slashes instead.
+        let url = format!(
+            "{}/{}",
+            self.base_url.trim_end_matches('/'),
+            path.trim_start_matches('/')
+        );
+
+        tracing::debug!("POST {}", url);
+        self.http_client
+            .post(url)
+            .header("Authorization", format!("Bearer {}", self.api_key))
+    }
+
+    /// Create an embedding model with the given name.
+    /// Note: default embedding dimension of 0 will be used if model is not known.
+    /// If this is the case, it's better to use function `embedding_model_with_ndims`
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::aliyun::{Client, self};
+    ///
+    /// // Initialize the Aliyun client
+    /// let aliyun = Client::new("your-dashscope-api-key");
+    ///
+    /// let embedding_model = aliyun.embedding_model("your-model-name");
+    /// ```
+    pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
+        EmbeddingModel::new(self.clone(), model, None)
+    }
+
+    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::aliyun::{Client, self};
+    ///
+    /// // Initialize the Aliyun client
+    /// let aliyun = Client::new("your-dashscope-api-key");
+    ///
+    /// let embedding_model = aliyun.embedding_model_with_ndims("model-unknown-to-rig", 1024);
+    /// ```
+    pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
+        EmbeddingModel::new(self.clone(), model, Some(ndims))
+    }
+
+    /// Create an embedding builder with the given embedding model.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::aliyun::{Client, self};
+    ///
+    /// // Initialize the Aliyun client
+    /// let aliyun = Client::new("your-dashscope-api-key");
+    ///
+    /// let embeddings = aliyun.embeddings("your-model-name")
+    ///     .document("Hello, world!")?
+    ///     .document("Goodbye, world!")?
+    ///     .build()
+    ///     .await
+    ///     .expect("Failed to embed documents");
+    /// ```
+    pub fn embeddings<D: Embed>(
+        &self,
+        model: &str,
+    ) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D> {
+        embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
+    }
+}
+
+#[derive(Debug, Deserialize)]
+pub struct ApiErrorResponse {
+    pub message: String,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(untagged)]
+pub enum ApiResponse<T> {
+    Ok(T),
+    Err(ApiErrorResponse),
+}
diff --git a/rig-core/src/providers/aliyun/embedding.rs b/rig-core/src/providers/aliyun/embedding.rs
new file mode 100644
index 0000000..bbc765d
--- /dev/null
+++ b/rig-core/src/providers/aliyun/embedding.rs
@@ -0,0 +1,250 @@
+// ================================================================
+//! Aliyun Embedding API Integration
+//! Implementation of Aliyun embedding models for text vectorization
+//! From https://help.aliyun.com/zh/model-studio/developer-reference/text-embedding-synchronous-api
+// ================================================================
+
+use serde_json::json;
+
+use crate::embeddings::{self, EmbeddingError};
+
+use super::client::{ApiResponse, Client};
+
+// Available embedding models provided by Aliyun
+pub const EMBEDDING_V1: &str = "text-embedding-v1";
+pub const EMBEDDING_V2: &str = "text-embedding-v2";
+pub const EMBEDDING_V3: &str = "text-embedding-v3";
+
+/// Aliyun embedding model implementation
+#[derive(Clone)]
+pub struct EmbeddingModel {
+    client: Client,
+    model: String,
+    ndims: Option<usize>,
+}
+
+impl EmbeddingModel {
+    /// Creates a new instance of the Aliyun embedding model
+    ///
+    /// # Arguments
+    /// * `client` - The Aliyun API client
+    /// * `model` - The model identifier (e.g., "text-embedding-v1")
+    /// * `ndims` - Optional custom dimension size for the embedding output
+    pub fn new(client: Client, model: &str, ndims: Option<usize>) -> Self {
+        Self {
+            client,
+            model: model.to_string(),
+            ndims,
+        }
+    }
+}
+
+impl EmbeddingModel {
+    /// Returns the maximum number of documents supported by the model
+    ///
+    /// # Returns
+    /// * For EMBEDDING_V3 model: 10 documents maximum
+    /// * For other models: 25 documents maximum
+    fn max_documents(&self) -> usize {
+        match self.model.as_str() {
+            EMBEDDING_V3 => 10,
+            _ => 25,
+        }
+    }
+
+    /// Returns the maximum number of tokens per document supported by the model
+    ///
+    /// # Returns
+    /// * For EMBEDDING_V3 model: 8192 tokens per document
+    /// * For other models: 2048 tokens per document
+    fn max_tokens(&self) -> usize {
+        match self.model.as_str() {
+            EMBEDDING_V3 => 8192,
+            _ => 2048,
+        }
+    }
+
+    /// Validates if the document list meets the model's constraints
+    ///
+    /// # Validation Checks
+    /// 1. Number of documents doesn't exceed model's maximum capacity
+    /// 2. Each document's token count is within the model's token limit
+    ///
+    /// # Returns
+    /// * `Ok(())` if validation passes
+    /// * `Err(EmbeddingError)` with appropriate error message if validation fails
+    fn validate_documents(&self, documents: &[String]) -> Result<(), EmbeddingError> {
+        // Rough heuristic; the API does not expose a local tokenizer.
+        const AVG_CHARS_PER_TOKEN: usize = 4;
+
+        if documents.len() > self.max_documents() {
+            return Err(EmbeddingError::ProviderError(format!(
+                "Model {} supports maximum {} documents",
+                self.model,
+                self.max_documents()
+            )));
+        }
+
+        for (i, doc) in documents.iter().enumerate() {
+            let estimated_tokens = doc.len() / AVG_CHARS_PER_TOKEN;
+            if estimated_tokens > self.max_tokens() {
+                return Err(EmbeddingError::ProviderError(format!(
+                    "Document #{} exceeds maximum token limit of {}",
+                    i + 1,
+                    self.max_tokens()
+                )));
+            }
+        }
+
+        Ok(())
+    }
+}
+
+impl embeddings::EmbeddingModel for EmbeddingModel {
+    // Upper bound used by the trait; text-embedding-v3's stricter limit of 10
+    // is enforced per-request in `validate_documents`.
+    const MAX_DOCUMENTS: usize = 25;
+
+    /// Returns the embedding dimension for the current model
+    ///
+    /// # Returns
+    /// * EMBEDDING_V1: 1536 dimensions
+    /// * EMBEDDING_V2: 1536 dimensions
+    /// * EMBEDDING_V3: 1024 dimensions (can be customized)
+    /// * Unknown models: 0 dimensions
+    fn ndims(&self) -> usize {
+        // A custom dimension (set via `embedding_model_with_ndims`)
+        // takes precedence over the model's default.
+        self.ndims.unwrap_or(match self.model.as_str() {
+            EMBEDDING_V1 => 1536,
+            EMBEDDING_V2 => 1536,
+
+            // V3 model defaults to 1024 dimensions
+            // Can be customized to [128, 256, 384, 512, 768, 1024]
+            EMBEDDING_V3 => 1024,
+            _ => 0, // Default to 0 for unknown models
+        })
+    }
+
+    /// Generates embeddings for the provided text documents
+    ///
+    /// # Arguments
+    /// * `documents` - Collection of text documents to embed
+    ///
+    /// # Returns
+    /// * `Result<Vec<embeddings::Embedding>, EmbeddingError>` - Vector of embeddings or error
+    #[cfg_attr(feature = "worker", worker::send)]
+    async fn embed_texts(
+        &self,
+        documents: impl IntoIterator<Item = String> + Send,
+    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
+        let documents: Vec<String> = documents.into_iter().collect();
+
+        self.validate_documents(&documents)?;
+
+        let request = json!({
+            "model": self.model,
+            "input": documents,
+            "dimension": self.ndims(),
+            "encoding_format": "float",
+        });
+
+        tracing::debug!("{}", serde_json::to_string_pretty(&request).unwrap());
+
+        let response = self
+            .client
+            .post("compatible-mode/v1/embeddings")
+            .json(&request)
+            .send()
+            .await?
+            .error_for_status()?
+            .json::<ApiResponse<aliyun_api_types::EmbeddingResponse>>()
+            .await?;
+
+        match response {
+            ApiResponse::Ok(response) => {
+                // Pair embeddings by the index the API reports rather than
+                // trusting response order.
+                let mut data = response.data;
+                data.sort_unstable_by_key(|d| d.index);
+
+                let docs = documents
+                    .into_iter()
+                    .zip(data)
+                    .map(|(document, embedding)| embeddings::Embedding {
+                        document,
+                        vec: embedding.embedding,
+                    })
+                    .collect();
+
+                Ok(docs)
+            }
+            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
+        }
+    }
+}
+
+// =================================================================
+// Aliyun API Types
+// =================================================================
+/// Type definitions for Aliyun Embedding API responses
+/// Follows OpenAI-compatible API structure
+#[allow(dead_code)]
+mod aliyun_api_types {
+    use serde::Deserialize;
+
+    /// Response structure for embedding requests
+    #[derive(Debug, Deserialize)]
+    pub struct EmbeddingResponse {
+        pub data: Vec<EmbeddingData>,
+        pub model: String,
+        pub object: String,
+        pub usage: Usage,
+        pub id: String,
+    }
+
+    /// Individual embedding data for a single input document
+    #[derive(Debug, Deserialize)]
+    pub struct EmbeddingData {
+        pub embedding: Vec<f64>,
+        pub index: usize,
+        pub object: String,
+    }
+
+    /// Token usage statistics for the embedding request
+    #[derive(Debug, Deserialize)]
+    pub struct Usage {
+        pub prompt_tokens: usize,
+        pub total_tokens: usize,
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::embeddings::embedding::EmbeddingModel as EmbeddingModelTrait;
+
+    #[tokio::test]
+    #[ignore = "requires DASHSCOPE_API_KEY and network access"]
+    async fn test_embed_texts() {
+        let client = Client::from_env();
+        let model = EmbeddingModel::new(client, EMBEDDING_V1, None);
+
+        // Test embedding for a single document
+        let documents = vec!["Hello, world!".to_string()];
+        let embeddings = model.embed_texts(documents).await.unwrap();
+
+        assert_eq!(embeddings.len(), 1);
+        assert_eq!(embeddings[0].vec.len(), 1536);
+
+        // Test embedding for multiple documents
+        let documents = vec!["Hello, world!".to_string(), "This is a test".to_string()];
+        let embeddings = model.embed_texts(documents).await.unwrap();
+
+        assert_eq!(embeddings.len(), 2);
+        assert_eq!(embeddings[0].vec.len(), 1536);
+        assert_eq!(embeddings[0].document, "Hello, world!");
+        assert_eq!(embeddings[1].vec.len(), 1536);
+        assert_eq!(embeddings[1].document, "This is a test");
+    }
+}
diff --git a/rig-core/src/providers/aliyun/mod.rs b/rig-core/src/providers/aliyun/mod.rs
new file mode 100644
index 0000000..95e380a
--- /dev/null
+++ b/rig-core/src/providers/aliyun/mod.rs
@@ -0,0 +1,4 @@
+pub mod client;
+pub mod embedding;
+
+pub use client::Client;
diff --git a/rig-core/src/providers/mod.rs b/rig-core/src/providers/mod.rs
index 99f7a94..04c8b21 100644
--- a/rig-core/src/providers/mod.rs
+++ b/rig-core/src/providers/mod.rs
@@ -62,3 +62,4 @@ pub mod openrouter;
 pub mod perplexity;
 pub mod together;
 pub mod xai;
+pub mod aliyun;