mirror of https://github.com/0xplaygrounds/rig
Merge c97b0c7548
into 33e8fc7a65
This commit is contained in:
commit
31ba7a5968
|
@ -127,3 +127,8 @@ required-features = ["audio"]
|
|||
[[example]]
|
||||
name = "hyperbolic_audio_generation"
|
||||
required-features = ["audio"]
|
||||
|
||||
|
||||
[[example]]
|
||||
name = "aliyun_embeddings"
|
||||
required-features = ["derive"]
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
use rig::providers::aliyun;
|
||||
use rig::Embed;
|
||||
|
||||
#[derive(Embed, Debug)]
|
||||
struct Greetings {
|
||||
#[embed]
|
||||
message: String,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
// Initialize the Aliyun client
|
||||
let client = aliyun::Client::from_env();
|
||||
|
||||
let embeddings = client
|
||||
.embeddings(aliyun::embedding::EMBEDDING_V1)
|
||||
.document(Greetings {
|
||||
message: "Hello, world!".to_string(),
|
||||
})?
|
||||
.document(Greetings {
|
||||
message: "Goodbye, world!".to_string(),
|
||||
})?
|
||||
.build()
|
||||
.await
|
||||
.expect("Failed to embed documents");
|
||||
|
||||
println!("{:?}", embeddings);
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -0,0 +1,162 @@
|
|||
use crate::{
|
||||
embeddings::{self},
|
||||
Embed,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
|
||||
use super::embedding::EmbeddingModel;
|
||||
|
||||
// ================================================================
|
||||
// Aliyun Client
|
||||
// ================================================================
|
||||
const ALIYUN_API_BASE_URL: &str = "https://dashscope.aliyuncs.com";
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Client {
|
||||
base_url: String,
|
||||
api_key: String,
|
||||
http_client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
/// Create a new Aliyun client with the given API key.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use rig::providers::aliyun::Client;
|
||||
///
|
||||
/// // Initialize the Aliyun client
|
||||
/// let aliyun = Client::new("your-dashscope-api-key");
|
||||
/// ```
|
||||
pub fn new(api_key: &str) -> Self {
|
||||
Self::from_url(api_key, ALIYUN_API_BASE_URL)
|
||||
}
|
||||
|
||||
/// Create a new Aliyun client with the given API key and base URL.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use rig::providers::aliyun::Client;
|
||||
///
|
||||
/// // Initialize the Aliyun client with a custom base URL
|
||||
/// let aliyun = Client::from_url("your-dashscope-api-key", "https://custom-dashscope-url.com");
|
||||
/// ```
|
||||
pub fn from_url(api_key: &str, base_url: &str) -> Self {
|
||||
Self {
|
||||
base_url: base_url.to_string(),
|
||||
api_key: api_key.to_string(),
|
||||
http_client: reqwest::Client::builder()
|
||||
.default_headers({
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
reqwest::header::CONTENT_TYPE,
|
||||
"application/json".parse().unwrap(),
|
||||
);
|
||||
headers
|
||||
})
|
||||
.build()
|
||||
.expect("Aliyun reqwest client should build"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new Aliyun client from the `DASHSCOPE_API_KEY` environment variable.
|
||||
/// Panics if the environment variable is not set.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use rig::providers::aliyun::Client;
|
||||
///
|
||||
/// // Initialize the Aliyun client from environment variable
|
||||
/// let aliyun = Client::from_env();
|
||||
/// ```
|
||||
/// # Panics
|
||||
/// This function will panic if the `DASHSCOPE_API_KEY` environment variable is not set.
|
||||
pub fn from_env() -> Self {
|
||||
let api_key = std::env::var("DASHSCOPE_API_KEY").expect("DASHSCOPE_API_KEY not set");
|
||||
Self::new(&api_key)
|
||||
}
|
||||
|
||||
/// Create a POST request to the specified API endpoint path.
|
||||
/// The Authorization header with the API key will be automatically added.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `path` - The API endpoint path to append to the base URL
|
||||
///
|
||||
/// # Returns
|
||||
/// A reqwest::RequestBuilder instance that can be further customized before sending
|
||||
pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
|
||||
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
|
||||
|
||||
tracing::debug!("POST {}/{}", self.base_url, path);
|
||||
self.http_client
|
||||
.post(url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
}
|
||||
|
||||
/// Create an embedding model with the given name.
|
||||
/// Note: default embedding dimension of 0 will be used if model is not known.
|
||||
/// If this is the case, it's better to use function `embedding_model_with_ndims`
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use rig::providers::aliyun::{Client, self};
|
||||
///
|
||||
/// // Initialize the Aliyun client
|
||||
/// let aliyun = Client::new("your-dashscope-api-key");
|
||||
///
|
||||
/// let embedding_model = aliyun.embedding_model("your-model-name");
|
||||
/// ```
|
||||
pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
|
||||
EmbeddingModel::new(self.clone(), model, None)
|
||||
}
|
||||
|
||||
/// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use rig::providers::aliyun::{Client, self};
|
||||
///
|
||||
/// // Initialize the Aliyun client
|
||||
/// let aliyun = Client::new("your-dashscope-api-key");
|
||||
///
|
||||
/// let embedding_model = aliyun.embedding_model_with_ndims("model-unknown-to-rig", 1024);
|
||||
/// ```
|
||||
pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
|
||||
EmbeddingModel::new(self.clone(), model, Some(ndims))
|
||||
}
|
||||
|
||||
/// Create an embedding builder with the given embedding model.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use rig::providers::aliyun::{Client, self};
|
||||
///
|
||||
/// // Initialize the Aliyun client
|
||||
/// let aliyun = Client::new("your-dashscope-api-key");
|
||||
///
|
||||
/// let embeddings = aliyun.embeddings("your-model-name")
|
||||
/// .simple_document("doc0", "Hello, world!")
|
||||
/// .simple_document("doc1", "Goodbye, world!")
|
||||
/// .build()
|
||||
/// .await
|
||||
/// .expect("Failed to embed documents");
|
||||
/// ```
|
||||
pub fn embeddings<D: Embed>(
|
||||
&self,
|
||||
model: &str,
|
||||
) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D> {
|
||||
embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ApiErrorResponse {
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ApiResponse<T> {
|
||||
Ok(T),
|
||||
Err(ApiErrorResponse),
|
||||
}
|
|
@ -0,0 +1,239 @@
|
|||
// ================================================================
|
||||
//! Aliyun Embedding API Integration
|
||||
//! Implementation of Aliyun embedding models for text vectorization
|
||||
//! See <https://help.aliyun.com/zh/model-studio/developer-reference/text-embedding-synchronous-api>
|
||||
// ================================================================
|
||||
|
||||
use serde_json::json;
|
||||
|
||||
use crate::embeddings::{self, EmbeddingError};
|
||||
|
||||
use super::client::{ApiResponse, Client};
|
||||
|
||||
// Identifiers of the embedding models provided by Aliyun.

/// Model identifier `text-embedding-v1`.
pub const EMBEDDING_V1: &str = "text-embedding-v1";
/// Model identifier `text-embedding-v2`.
pub const EMBEDDING_V2: &str = "text-embedding-v2";
/// Model identifier `text-embedding-v3`.
pub const EMBEDDING_V3: &str = "text-embedding-v3";
|
||||
|
||||
/// Aliyun embedding model implementation
|
||||
#[derive(Clone)]
|
||||
pub struct EmbeddingModel {
|
||||
client: Client,
|
||||
model: String,
|
||||
ndims: Option<usize>,
|
||||
}
|
||||
|
||||
impl EmbeddingModel {
|
||||
/// Creates a new instance of the Aliyun embedding model
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `client` - The Aliyun API client
|
||||
/// * `model` - The model identifier (e.g., "text-embedding-v1")
|
||||
/// * `ndims` - Optional custom dimension size for the embedding output
|
||||
pub fn new(client: Client, model: &str, ndims: Option<usize>) -> Self {
|
||||
Self {
|
||||
client,
|
||||
model: model.to_string(),
|
||||
ndims,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EmbeddingModel {
|
||||
/// Returns the maximum number of documents supported by the model
|
||||
///
|
||||
/// # Returns
|
||||
/// * For EMBEDDING_V3 model: 10 documents maximum
|
||||
/// * For other models: 25 documents maximum
|
||||
fn max_documents(&self) -> usize {
|
||||
match self.model.as_str() {
|
||||
EMBEDDING_V3 => 10,
|
||||
_ => 25,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the maximum number of tokens per document supported by the model
|
||||
///
|
||||
/// # Returns
|
||||
/// * For EMBEDDING_V3 model: 8192 tokens per document
|
||||
/// * For other models: 2048 tokens per document
|
||||
fn max_tokens(&self) -> usize {
|
||||
match self.model.as_str() {
|
||||
EMBEDDING_V3 => 8192,
|
||||
_ => 2048,
|
||||
}
|
||||
}
|
||||
|
||||
/// Validates if the document list meets the model's constraints
|
||||
///
|
||||
/// # Validation Checks
|
||||
/// 1. Number of documents doesn't exceed model's maximum capacity
|
||||
/// 2. Each document's token count is within the model's token limit
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok(())` if validation passes
|
||||
/// * `Err(EmbeddingError)` with appropriate error message if validation fails
|
||||
fn validate_documents(&self, documents: &[String]) -> Result<(), EmbeddingError> {
|
||||
const AVG_CHARS_PER_TOKEN: usize = 4;
|
||||
|
||||
if documents.len() > self.max_documents() {
|
||||
return Err(EmbeddingError::ProviderError(format!(
|
||||
"Model {} supports maximum {} documents",
|
||||
self.model,
|
||||
self.max_documents()
|
||||
)));
|
||||
}
|
||||
|
||||
for (i, doc) in documents.iter().enumerate() {
|
||||
let estimated_tokens = doc.len() / AVG_CHARS_PER_TOKEN;
|
||||
if estimated_tokens > self.max_tokens() {
|
||||
return Err(EmbeddingError::ProviderError(format!(
|
||||
"Document #{} exceeds maximum token limit of {}",
|
||||
i + 1,
|
||||
self.max_tokens()
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl embeddings::EmbeddingModel for EmbeddingModel {
|
||||
const MAX_DOCUMENTS: usize = 25;
|
||||
|
||||
/// Returns the default embedding dimension for the current model
|
||||
///
|
||||
/// # Returns
|
||||
/// * EMBEDDING_V1: 1536 dimensions
|
||||
/// * EMBEDDING_V2: 1536 dimensions
|
||||
/// * EMBEDDING_V3: 1024 dimensions (can be customized)
|
||||
/// * Unknown models: 0 dimensions
|
||||
fn ndims(&self) -> usize {
|
||||
match self.model.as_str() {
|
||||
EMBEDDING_V1 => 1536,
|
||||
EMBEDDING_V2 => 1536,
|
||||
|
||||
// V3 model defaults to 1024 dimensions
|
||||
// Can be customized to [128, 256, 384, 512, 768, 1024]
|
||||
EMBEDDING_V3 => 1024,
|
||||
_ => 0, // Default to 0 for unknown models
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates embeddings for the provided text documents
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `documents` - Collection of text documents to embed
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<Vec<Embedding>, EmbeddingError>` - Vector of embeddings or error
|
||||
#[cfg_attr(feature = "worker", worker::send)]
|
||||
async fn embed_texts(
|
||||
&self,
|
||||
documents: impl IntoIterator<Item = String> + Send,
|
||||
) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
|
||||
let documents: Vec<String> = documents.into_iter().collect();
|
||||
|
||||
self.validate_documents(&documents)?;
|
||||
|
||||
let request = json!({
|
||||
"model": self.model,
|
||||
"input": documents,
|
||||
"dimension": self.ndims.unwrap_or(self.ndims()),
|
||||
"encoding_format": "float",
|
||||
});
|
||||
|
||||
tracing::info!("{}", serde_json::to_string_pretty(&request).unwrap());
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&format!("/compatible-mode/v1/embeddings"))
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?
|
||||
.json::<ApiResponse<aliyun_api_types::EmbeddingResponse>>()
|
||||
.await?;
|
||||
|
||||
match response {
|
||||
ApiResponse::Ok(response) => {
|
||||
let docs = documents
|
||||
.into_iter()
|
||||
.zip(response.data)
|
||||
.map(|(document, embedding)| embeddings::Embedding {
|
||||
document,
|
||||
vec: embedding.embedding,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(docs)
|
||||
}
|
||||
ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// =================================================================
|
||||
// Aliyun API Types
|
||||
// =================================================================
|
||||
/// Type definitions for Aliyun Embedding API responses
|
||||
/// Follows OpenAI-compatible API structure
|
||||
#[allow(dead_code)]
|
||||
mod aliyun_api_types {
|
||||
use serde::Deserialize;
|
||||
|
||||
/// Response structure for embedding requests
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct EmbeddingResponse {
|
||||
pub data: Vec<EmbeddingData>,
|
||||
pub model: String,
|
||||
pub object: String,
|
||||
pub usage: Usage,
|
||||
pub id: String,
|
||||
}
|
||||
|
||||
/// Individual embedding data for a single input document
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct EmbeddingData {
|
||||
pub embedding: Vec<f64>,
|
||||
pub index: usize,
|
||||
pub object: String,
|
||||
}
|
||||
|
||||
/// Token usage statistics for the embedding request
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: usize,
|
||||
pub total_tokens: usize,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embeddings::embedding::EmbeddingModel as EmbeddingModelTrait;

    /// Live round-trip against the Aliyun API with `text-embedding-v1`.
    ///
    /// Ignored by default: `Client::from_env()` panics when `DASHSCOPE_API_KEY`
    /// is unset and the test needs network access. Run it explicitly with
    /// `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "requires DASHSCOPE_API_KEY and network access to the Aliyun API"]
    async fn test_embed_texts() {
        let client = Client::from_env();
        let model = EmbeddingModel::new(client, EMBEDDING_V1, None);

        // Test embedding for a single document
        let documents = vec!["Hello, world!".to_string()];
        let embeddings = model.embed_texts(documents).await.unwrap();

        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].vec.len(), 1536);

        // Test embedding for multiple documents
        let documents = vec!["Hello, world!".to_string(), "This is a test".to_string()];
        let embeddings = model.embed_texts(documents).await.unwrap();

        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].vec.len(), 1536);
        assert_eq!(embeddings[0].document, "Hello, world!");
        assert_eq!(embeddings[1].vec.len(), 1536);
        assert_eq!(embeddings[1].document, "This is a test");
    }
}
|
|
@ -0,0 +1,4 @@
|
|||
pub mod client;
|
||||
pub mod embedding;
|
||||
|
||||
pub use client::Client;
|
|
@ -62,3 +62,4 @@ pub mod openrouter;
|
|||
pub mod perplexity;
|
||||
pub mod together;
|
||||
pub mod xai;
|
||||
pub mod aliyun;
|
||||
|
|
Loading…
Reference in New Issue