diff --git a/rig-core/src/providers/inception/client.rs b/rig-core/src/providers/inception/client.rs
new file mode 100644
index 0000000..7871b08
--- /dev/null
+++ b/rig-core/src/providers/inception/client.rs
@@ -0,0 +1,105 @@
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+
+use crate::{
+    agent::AgentBuilder, extractor::ExtractorBuilder,
+    providers::inception::completion::CompletionModel,
+};
+
+const INCEPTION_API_BASE_URL: &str = "https://api.inceptionlabs.ai/v1";
+
+/// Builder for an Inception [`Client`]; allows overriding the API base URL.
+#[derive(Clone)]
+pub struct ClientBuilder<'a> {
+    api_key: &'a str,
+    base_url: &'a str,
+}
+
+impl<'a> ClientBuilder<'a> {
+    pub fn new(api_key: &'a str) -> Self {
+        Self {
+            api_key,
+            base_url: INCEPTION_API_BASE_URL,
+        }
+    }
+
+    pub fn base_url(mut self, base_url: &'a str) -> Self {
+        self.base_url = base_url;
+        self
+    }
+
+    pub fn build(self) -> Client {
+        Client::new(self.api_key, self.base_url)
+    }
+}
+
+/// Client for the Inception Labs completion API.
+#[derive(Clone)]
+pub struct Client {
+    base_url: String,
+    http_client: reqwest::Client,
+}
+
+impl Client {
+    /// Creates a client that sends `Authorization: Bearer <api_key>` and
+    /// `Content-Type: application/json` headers on every request.
+    pub fn new(api_key: &str, base_url: &str) -> Self {
+        Self {
+            base_url: base_url.to_string(),
+            http_client: reqwest::Client::builder()
+                .default_headers({
+                    let mut headers = reqwest::header::HeaderMap::new();
+                    headers.insert(
+                        "Content-Type",
+                        "application/json"
+                            .parse()
+                            .expect("Content-Type should parse"),
+                    );
+                    headers.insert(
+                        "Authorization",
+                        format!("Bearer {}", api_key)
+                            .parse()
+                            .expect("Authorization should parse"),
+                    );
+                    headers
+                })
+                .build()
+                .expect("Inception reqwest client should build"),
+        }
+    }
+
+    /// Creates a client from the `INCEPTION_API_KEY` environment variable.
+    ///
+    /// # Panics
+    /// Panics if `INCEPTION_API_KEY` is not set.
+    pub fn from_env() -> Self {
+        let api_key = std::env::var("INCEPTION_API_KEY").expect("INCEPTION_API_KEY not set");
+        ClientBuilder::new(&api_key).build()
+    }
+
+    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
+        // Join base URL and path with exactly one '/'. A blanket
+        // `.replace("//", "/")` would also corrupt the "https://" scheme.
+        let url = format!(
+            "{}/{}",
+            self.base_url.trim_end_matches('/'),
+            path.trim_start_matches('/')
+        );
+        self.http_client.post(url)
+    }
+
+    pub fn completion_model(&self, model: &str) -> CompletionModel {
+        CompletionModel::new(self.clone(), model)
+    }
+
+    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
+        AgentBuilder::new(self.completion_model(model))
+    }
+
+    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
+        &self,
+        model: &str,
+    ) -> ExtractorBuilder<T, CompletionModel> {
+        ExtractorBuilder::new(self.completion_model(model))
+    }
+}
diff --git a/rig-core/src/providers/inception/completion.rs b/rig-core/src/providers/inception/completion.rs
new file mode 100644
index 0000000..18f19dd
--- /dev/null
+++ b/rig-core/src/providers/inception/completion.rs
@@ -0,0 +1,200 @@
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+
+use crate::{
+    completion::{self, CompletionError},
+    message::{self, MessageError},
+    OneOrMany,
+};
+
+use super::client::Client;
+
+// ================================================================
+// Inception Completion API
+// ================================================================
+/// `mercury-coder-small` completion model
+pub const MERCURY_CODER_SMALL: &str = "mercury-coder-small";
+
+/// Raw response body of the Inception `/chat/completions` endpoint.
+#[derive(Debug, Deserialize)]
+pub struct CompletionResponse {
+    pub id: String,
+    pub choices: Vec<Choice>,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub usage: Usage,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Choice {
+    pub index: usize,
+    pub message: Message,
+    pub finish_reason: String,
+}
+
+impl From<Choice> for completion::AssistantContent {
+    fn from(choice: Choice) -> Self {
+        completion::AssistantContent::from(&choice)
+    }
+}
+
+impl From<&Choice> for completion::AssistantContent {
+    fn from(choice: &Choice) -> Self {
+        completion::AssistantContent::Text(completion::message::Text {
+            text: choice.message.content.clone(),
+        })
+    }
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct Message {
+    pub role: Role,
+    pub content: String,
+}
+
+#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+    User,
+    Assistant,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Usage {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub total_tokens: u32,
+}
+
+impl std::fmt::Display for Usage {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "Prompt tokens: {}\nCompletion tokens: {}\nTotal tokens: {}",
+            self.prompt_tokens, self.completion_tokens, self.total_tokens
+        )
+    }
+}
+
+impl TryFrom<message::Message> for Message {
+    type Error = MessageError;
+
+    // Only single text parts are supported; non-text content is rejected.
+    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
+        Ok(match message {
+            message::Message::User { content } => Message {
+                role: Role::User,
+                content: match content.first() {
+                    message::UserContent::Text(message::Text { text }) => text.clone(),
+                    _ => {
+                        return Err(MessageError::ConversionError(
+                            "User message content must be a text message".to_string(),
+                        ))
+                    }
+                },
+            },
+            message::Message::Assistant { content } => Message {
+                role: Role::Assistant,
+                content: match content.first() {
+                    message::AssistantContent::Text(message::Text { text }) => text.clone(),
+                    _ => {
+                        return Err(MessageError::ConversionError(
+                            "Assistant message content must be a text message".to_string(),
+                        ))
+                    }
+                },
+            },
+        })
+    }
+}
+
+impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
+    type Error = CompletionError;
+
+    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
+        let content = response.choices.iter().map(Into::into).collect::<Vec<_>>();
+
+        let choice = OneOrMany::many(content).map_err(|_| {
+            CompletionError::ResponseError(
+                "Response contained no message or tool call (empty)".to_owned(),
+            )
+        })?;
+
+        Ok(completion::CompletionResponse {
+            choice,
+            raw_response: response,
+        })
+    }
+}
+
+const MAX_TOKENS: u64 = 8192;
+
+#[derive(Clone)]
+pub struct CompletionModel {
+    pub(crate) client: Client,
+    pub model: String,
+}
+
+impl CompletionModel {
+    pub fn new(client: Client, model: &str) -> Self {
+        Self {
+            client,
+            model: model.to_string(),
+        }
+    }
+}
+
+impl completion::CompletionModel for CompletionModel {
+    type Response = CompletionResponse;
+
+    #[cfg_attr(feature = "worker", worker::send)]
+    async fn completion(
+        &self,
+        completion_request: completion::CompletionRequest,
+    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
+        let max_tokens = completion_request.max_tokens.unwrap_or(MAX_TOKENS);
+
+        let prompt_message: Message = completion_request
+            .prompt_with_context()
+            .try_into()
+            .map_err(|e: MessageError| CompletionError::RequestError(e.into()))?;
+
+        let mut messages = completion_request
+            .chat_history
+            .into_iter()
+            .map(|message| {
+                message
+                    .try_into()
+                    .map_err(|e: MessageError| CompletionError::RequestError(e.into()))
+            })
+            .collect::<Result<Vec<Message>, _>>()?;
+
+        // The prompt (with any static context prepended) always goes last.
+        messages.push(prompt_message);
+
+        let request = json!({
+            "model": self.model,
+            "messages": messages,
+            "max_tokens": max_tokens,
+        });
+
+        let response = self
+            .client
+            .post("/chat/completions")
+            .json(&request)
+            .send()
+            .await?;
+
+        if response.status().is_success() {
+            let response = response.json::<CompletionResponse>().await?;
+            tracing::info!(target: "rig",
+                "Inception completion token usage: {}",
+                response.usage
+            );
+            Ok(response.try_into()?)
+        } else {
+            Err(CompletionError::ProviderError(response.text().await?))
+        }
+    }
+}
diff --git a/rig-core/src/providers/inception/mod.rs b/rig-core/src/providers/inception/mod.rs
new file mode 100644
index 0000000..733191c
--- /dev/null
+++ b/rig-core/src/providers/inception/mod.rs
@@ -0,0 +1,2 @@
+pub mod client;
+pub mod completion;
diff --git a/rig-core/src/providers/mod.rs b/rig-core/src/providers/mod.rs
index 99f7a94..fe2ee2e 100644
--- a/rig-core/src/providers/mod.rs
+++ b/rig-core/src/providers/mod.rs
@@ -11,6 +11,7 @@
 //! - DeepSeek
 //! - Azure OpenAI
 //! - Mira
+//! - Inception
 //!
 //! Each provider has its own module, which contains a `Client` implementation that can
 //! be used to initialize completion and embedding models and execute requests to those models.
@@ -54,6 +55,7 @@ pub mod gemini;
 pub mod groq;
 pub mod huggingface;
 pub mod hyperbolic;
+pub mod inception;
 pub mod mira;
 pub mod moonshot;
 pub mod ollama;