mztlive 2025-04-19 01:53:28 +00:00 committed by GitHub
commit 31ba7a5968
6 changed files with 441 additions and 0 deletions

View File

@@ -127,3 +127,8 @@ required-features = ["audio"]
[[example]]
name = "hyperbolic_audio_generation"
required-features = ["audio"]
[[example]]
name = "aliyun_embeddings"
required-features = ["derive"]

View File

@@ -0,0 +1,30 @@
use rig::providers::aliyun;
use rig::Embed;
#[derive(Embed, Debug)]
struct Greetings {
#[embed]
message: String,
}
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Initialize the Aliyun client
let client = aliyun::Client::from_env();
let embeddings = client
.embeddings(aliyun::embedding::EMBEDDING_V1)
.document(Greetings {
message: "Hello, world!".to_string(),
})?
.document(Greetings {
message: "Goodbye, world!".to_string(),
})?
.build()
.await
.expect("Failed to embed documents");
println!("{:?}", embeddings);
Ok(())
}
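// To try this example locally (a hedged guess at the usual rig layout; adjust the
// package/feature flags to your checkout), something like the following should work
// once a valid DashScope key is exported:
//
//     DASHSCOPE_API_KEY=sk-your-key cargo run --example aliyun_embeddings --features derive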

View File

@@ -0,0 +1,162 @@
use crate::{
embeddings::{self},
Embed,
};
use serde::Deserialize;
use super::embedding::EmbeddingModel;
// ================================================================
// Aliyun Client
// ================================================================
const ALIYUN_API_BASE_URL: &str = "https://dashscope.aliyuncs.com";
#[derive(Clone)]
pub struct Client {
base_url: String,
api_key: String,
http_client: reqwest::Client,
}
impl Client {
/// Create a new Aliyun client with the given API key.
///
/// # Example
/// ```
/// use rig::providers::aliyun::Client;
///
/// // Initialize the Aliyun client
/// let aliyun = Client::new("your-dashscope-api-key");
/// ```
pub fn new(api_key: &str) -> Self {
Self::from_url(api_key, ALIYUN_API_BASE_URL)
}
/// Create a new Aliyun client with the given API key and base URL.
///
/// # Example
/// ```
/// use rig::providers::aliyun::Client;
///
/// // Initialize the Aliyun client with a custom base URL
/// let aliyun = Client::from_url("your-dashscope-api-key", "https://custom-dashscope-url.com");
/// ```
pub fn from_url(api_key: &str, base_url: &str) -> Self {
Self {
base_url: base_url.to_string(),
api_key: api_key.to_string(),
http_client: reqwest::Client::builder()
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
);
headers
})
.build()
.expect("Aliyun reqwest client should build"),
}
}
/// Create a new Aliyun client from the `DASHSCOPE_API_KEY` environment variable.
/// Panics if the environment variable is not set.
///
/// # Example
/// ```
/// use rig::providers::aliyun::Client;
///
/// // Initialize the Aliyun client from environment variable
/// let aliyun = Client::from_env();
/// ```
/// # Panics
/// This function will panic if the `DASHSCOPE_API_KEY` environment variable is not set.
pub fn from_env() -> Self {
let api_key = std::env::var("DASHSCOPE_API_KEY").expect("DASHSCOPE_API_KEY not set");
Self::new(&api_key)
}
/// Create a POST request to the specified API endpoint path.
/// The Authorization header with the API key will be automatically added.
///
/// # Arguments
/// * `path` - The API endpoint path to append to the base URL
///
/// # Returns
/// A reqwest::RequestBuilder instance that can be further customized before sending
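///
/// # Example (illustrative sketch, not part of the original commit)
/// ```
/// use rig::providers::aliyun::Client;
///
/// let aliyun = Client::new("your-dashscope-api-key");
/// // Illustrative endpoint path; building the request does not send it.
/// let request = aliyun.post("compatible-mode/v1/embeddings");
/// ```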
pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
// Join the base URL and endpoint path without mangling the "//" in the scheme.
let url = format!("{}/{}", self.base_url.trim_end_matches('/'), path.trim_start_matches('/'));
tracing::debug!("POST {}", url);
self.http_client
.post(url)
.header("Authorization", format!("Bearer {}", self.api_key))
}
/// Create an embedding model with the given name.
/// Note: if the model is not known to rig, a default embedding dimension of 0 is used.
/// In that case, prefer `embedding_model_with_ndims` and pass the dimensions explicitly.
///
/// # Example
/// ```
/// use rig::providers::aliyun::{Client, self};
///
/// // Initialize the Aliyun client
/// let aliyun = Client::new("your-dashscope-api-key");
///
/// let embedding_model = aliyun.embedding_model("your-model-name");
/// ```
pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
EmbeddingModel::new(self.clone(), model, None)
}
/// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
///
/// # Example
/// ```
/// use rig::providers::aliyun::{Client, self};
///
/// // Initialize the Aliyun client
/// let aliyun = Client::new("your-dashscope-api-key");
///
/// let embedding_model = aliyun.embedding_model_with_ndims("model-unknown-to-rig", 1024);
/// ```
pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
EmbeddingModel::new(self.clone(), model, Some(ndims))
}
/// Create an embedding builder with the given embedding model.
///
/// # Example
/// ```
/// use rig::providers::aliyun::{Client, self};
///
/// // Initialize the Aliyun client
/// let aliyun = Client::new("your-dashscope-api-key");
///
/// let embeddings = aliyun.embeddings("your-model-name")
///     .document("Hello, world!".to_string())?
///     .document("Goodbye, world!".to_string())?
///     .build()
///     .await
///     .expect("Failed to embed documents");
/// ```
pub fn embeddings<D: Embed>(
&self,
model: &str,
) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D> {
embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
}
}
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
pub message: String,
}
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
Ok(T),
Err(ApiErrorResponse),
}
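// Hedged sketch added for illustration (not part of the original commit): because
// ApiResponse is #[serde(untagged)], serde tries the Ok variant first and falls back
// to Err, so a DashScope error body such as {"message": "..."} still deserializes.
#[cfg(test)]
mod tests {
    use super::ApiResponse;
    use serde::Deserialize;

    // Minimal stand-in for a success payload; only used by this test.
    #[derive(Debug, Deserialize)]
    struct Dummy {
        data: Vec<f64>,
    }

    #[test]
    fn untagged_response_decodes_success_and_error_bodies() {
        let ok: ApiResponse<Dummy> = serde_json::from_str(r#"{"data": [0.1, 0.2]}"#).unwrap();
        assert!(matches!(ok, ApiResponse::Ok(_)));

        let err: ApiResponse<Dummy> = serde_json::from_str(r#"{"message": "Invalid API key"}"#).unwrap();
        match err {
            ApiResponse::Err(e) => assert_eq!(e.message, "Invalid API key"),
            ApiResponse::Ok(_) => panic!("expected the error variant"),
        }
    }
}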

View File

@@ -0,0 +1,239 @@
// ================================================================
//! Aliyun Embedding API Integration
//! Implementation of Aliyun embedding models for text vectorization
//! API reference: https://help.aliyun.com/zh/model-studio/developer-reference/text-embedding-synchronous-api
// ================================================================
use serde_json::json;
use crate::embeddings::{self, EmbeddingError};
use super::client::{ApiResponse, Client};
// Available embedding models provided by Aliyun
pub const EMBEDDING_V1: &str = "text-embedding-v1";
pub const EMBEDDING_V2: &str = "text-embedding-v2";
pub const EMBEDDING_V3: &str = "text-embedding-v3";
/// Aliyun embedding model implementation
#[derive(Clone)]
pub struct EmbeddingModel {
client: Client,
model: String,
ndims: Option<usize>,
}
impl EmbeddingModel {
/// Creates a new instance of the Aliyun embedding model
///
/// # Arguments
/// * `client` - The Aliyun API client
/// * `model` - The model identifier (e.g., "text-embedding-v1")
/// * `ndims` - Optional custom dimension size for the embedding output
pub fn new(client: Client, model: &str, ndims: Option<usize>) -> Self {
Self {
client,
model: model.to_string(),
ndims,
}
}
}
impl EmbeddingModel {
/// Returns the maximum number of documents supported by the model
///
/// # Returns
/// * For EMBEDDING_V3 model: 10 documents maximum
/// * For other models: 25 documents maximum
fn max_documents(&self) -> usize {
match self.model.as_str() {
EMBEDDING_V3 => 10,
_ => 25,
}
}
/// Returns the maximum number of tokens per document supported by the model
///
/// # Returns
/// * For EMBEDDING_V3 model: 8192 tokens per document
/// * For other models: 2048 tokens per document
fn max_tokens(&self) -> usize {
match self.model.as_str() {
EMBEDDING_V3 => 8192,
_ => 2048,
}
}
/// Validates if the document list meets the model's constraints
///
/// # Validation Checks
/// 1. Number of documents doesn't exceed model's maximum capacity
/// 2. Each document's token count is within the model's token limit
///
/// # Returns
/// * `Ok(())` if validation passes
/// * `Err(EmbeddingError)` with appropriate error message if validation fails
fn validate_documents(&self, documents: &[String]) -> Result<(), EmbeddingError> {
const AVG_CHARS_PER_TOKEN: usize = 4;
if documents.len() > self.max_documents() {
return Err(EmbeddingError::ProviderError(format!(
"Model {} supports maximum {} documents",
self.model,
self.max_documents()
)));
}
for (i, doc) in documents.iter().enumerate() {
let estimated_tokens = doc.len() / AVG_CHARS_PER_TOKEN;
if estimated_tokens > self.max_tokens() {
return Err(EmbeddingError::ProviderError(format!(
"Document #{} exceeds maximum token limit of {}",
i + 1,
self.max_tokens()
)));
}
}
Ok(())
}
}
impl embeddings::EmbeddingModel for EmbeddingModel {
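// Generic upper bound for the embeddings trait; validate_documents above enforces
// the stricter per-model limits (10 documents for text-embedding-v3).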
const MAX_DOCUMENTS: usize = 25;
/// Returns the default embedding dimension for the current model
///
/// # Returns
/// * EMBEDDING_V1: 1536 dimensions
/// * EMBEDDING_V2: 1536 dimensions
/// * EMBEDDING_V3: 1024 dimensions (can be customized)
/// * Unknown models: 0 dimensions
fn ndims(&self) -> usize {
match self.model.as_str() {
EMBEDDING_V1 => 1536,
EMBEDDING_V2 => 1536,
// V3 model defaults to 1024 dimensions
// Can be customized to [128, 256, 384, 512, 768, 1024]
EMBEDDING_V3 => 1024,
_ => 0, // Default to 0 for unknown models
}
}
/// Generates embeddings for the provided text documents
///
/// # Arguments
/// * `documents` - Collection of text documents to embed
///
/// # Returns
/// * `Result<Vec<Embedding>, EmbeddingError>` - Vector of embeddings or error
#[cfg_attr(feature = "worker", worker::send)]
async fn embed_texts(
&self,
documents: impl IntoIterator<Item = String> + Send,
) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
let documents: Vec<String> = documents.into_iter().collect();
self.validate_documents(&documents)?;
let request = json!({
"model": self.model,
"input": documents,
"dimension": self.ndims.unwrap_or(self.ndims()),
"encoding_format": "float",
});
tracing::info!("{}", serde_json::to_string_pretty(&request).unwrap());
let response = self
.client
.post("compatible-mode/v1/embeddings")
.json(&request)
.send()
.await?
.error_for_status()?
.json::<ApiResponse<aliyun_api_types::EmbeddingResponse>>()
.await?;
match response {
ApiResponse::Ok(response) => {
let docs = documents
.into_iter()
.zip(response.data)
.map(|(document, embedding)| embeddings::Embedding {
document,
vec: embedding.embedding,
})
.collect();
Ok(docs)
}
ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
}
}
}
// =================================================================
// Aliyun API Types
// =================================================================
/// Type definitions for Aliyun Embedding API responses
/// Follows OpenAI-compatible API structure
#[allow(dead_code)]
mod aliyun_api_types {
use serde::Deserialize;
/// Response structure for embedding requests
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
pub data: Vec<EmbeddingData>,
pub model: String,
pub object: String,
pub usage: Usage,
pub id: String,
}
/// Individual embedding data for a single input document
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
pub embedding: Vec<f64>,
pub index: usize,
pub object: String,
}
/// Token usage statistics for the embedding request
#[derive(Debug, Deserialize)]
pub struct Usage {
pub prompt_tokens: usize,
pub total_tokens: usize,
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::embeddings::embedding::EmbeddingModel as EmbeddingModelTrait;
#[tokio::test]
async fn test_embed_texts() {
let client = Client::from_env();
let model = EmbeddingModel::new(client, EMBEDDING_V1, None);
// Test embedding for a single document
let documents = vec!["Hello, world!".to_string()];
let embeddings = model.embed_texts(documents).await.unwrap();
assert_eq!(embeddings.len(), 1);
assert_eq!(embeddings[0].vec.len(), 1536);
// Test embedding for multiple documents
let documents = vec!["Hello, world!".to_string(), "This is a test".to_string()];
let embeddings = model.embed_texts(documents).await.unwrap();
assert_eq!(embeddings.len(), 2);
assert_eq!(embeddings[0].vec.len(), 1536);
assert_eq!(embeddings[0].document, "Hello, world!");
assert_eq!(embeddings[1].vec.len(), 1536);
assert_eq!(embeddings[1].document, "This is a test");
}
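// Hedged sketch added for illustration (not part of the original commit): exercises
// the per-model document-count guard directly, with no network I/O, so it does not
// need a valid DASHSCOPE_API_KEY.
#[test]
fn test_validate_documents_limit() {
    let client = Client::new("dummy-api-key");
    let model = EmbeddingModel::new(client, EMBEDDING_V3, None);
    // EMBEDDING_V3 accepts at most 10 documents per request.
    let ok_docs: Vec<String> = (0..10).map(|i| format!("doc {}", i)).collect();
    assert!(model.validate_documents(&ok_docs).is_ok());
    let too_many: Vec<String> = (0..11).map(|i| format!("doc {}", i)).collect();
    assert!(model.validate_documents(&too_many).is_err());
}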
}

View File

@@ -0,0 +1,4 @@
pub mod client;
pub mod embedding;
pub use client::Client;

View File

@@ -62,3 +62,4 @@ pub mod openrouter;
pub mod perplexity;
pub mod together;
pub mod xai;
pub mod aliyun;