From 01c93176cfb6ad263f7022c3da8a0ffedeb12427 Mon Sep 17 00:00:00 2001 From: Collin Brittain Date: Fri, 11 Apr 2025 14:03:52 -0500 Subject: [PATCH] Simplify inception provider module to use openai streaming api --- rig-core/examples/agent_with_inception.rs | 7 +- rig-core/examples/inception_streaming.rs | 2 +- rig-core/src/providers/inception.rs | 325 ++++++++++++++++++ rig-core/src/providers/inception/client.rs | 91 ----- .../src/providers/inception/completion.rs | 197 ----------- rig-core/src/providers/inception/mod.rs | 6 - rig-core/src/providers/inception/streaming.rs | 121 ------- 7 files changed, 328 insertions(+), 421 deletions(-) create mode 100644 rig-core/src/providers/inception.rs delete mode 100644 rig-core/src/providers/inception/client.rs delete mode 100644 rig-core/src/providers/inception/completion.rs delete mode 100644 rig-core/src/providers/inception/mod.rs delete mode 100644 rig-core/src/providers/inception/streaming.rs diff --git a/rig-core/examples/agent_with_inception.rs b/rig-core/examples/agent_with_inception.rs index c740295..9c140bf 100644 --- a/rig-core/examples/agent_with_inception.rs +++ b/rig-core/examples/agent_with_inception.rs @@ -2,21 +2,18 @@ use std::env; use rig::{ completion::Prompt, - providers::inception::{ClientBuilder, MERCURY_CODER_SMALL}, + providers::inception::{Client, MERCURY_CODER_SMALL}, }; #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Create Inception Labs client - let client = - ClientBuilder::new(&env::var("INCEPTION_API_KEY").expect("INCEPTION_API_KEY not set")) - .build(); + let client = Client::new(&env::var("INCEPTION_API_KEY").expect("INCEPTION_API_KEY not set")); // Create agent with a single context prompt let agent = client .agent(MERCURY_CODER_SMALL) .preamble("You are a helpful AI assistant.") - .temperature(0.0) .build(); // Prompt the agent and print the response diff --git a/rig-core/examples/inception_streaming.rs b/rig-core/examples/inception_streaming.rs index 1e63f74..9fb8793 100644 --- a/rig-core/examples/inception_streaming.rs +++ b/rig-core/examples/inception_streaming.rs @@ -1,5 +1,5 @@ use rig::{ - providers::inception::{self, completion::MERCURY_CODER_SMALL}, + providers::inception::{self, MERCURY_CODER_SMALL}, streaming::{stream_to_stdout, StreamingPrompt}, }; diff --git a/rig-core/src/providers/inception.rs b/rig-core/src/providers/inception.rs new file mode 100644 index 0000000..1dbaa73 --- /dev/null +++ b/rig-core/src/providers/inception.rs @@ -0,0 +1,325 @@ +use super::openai::send_compatible_streaming_request; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; + +use crate::{ + agent::AgentBuilder, + completion::{self, CompletionError, CompletionRequest}, + extractor::ExtractorBuilder, + json_utils::{self, merge_inplace}, + message::{self, MessageError}, + streaming::{StreamingCompletionModel, StreamingResult}, + OneOrMany, +}; + +const INCEPTION_API_BASE_URL: &str = "https://api.inceptionlabs.ai/v1"; + +#[derive(Clone)] +pub struct Client { + base_url: String, + http_client: reqwest::Client, +} + +impl Client { + pub fn new(api_key: &str) -> Self { + Self { + base_url: INCEPTION_API_BASE_URL.to_string(), + http_client: reqwest::Client::builder() + .default_headers({ + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + "Content-Type", + "application/json" + .parse() + .expect("Content-Type should parse"), + ); + headers.insert( + "Authorization", + format!("Bearer {}", api_key) + .parse() + .expect("Authorization should parse"), + ); + headers + }) + .build() + .expect("Inception reqwest client should build"), + } + } + + pub fn from_env() -> Self { + let api_key = std::env::var("INCEPTION_API_KEY").expect("INCEPTION_API_KEY not set"); + Client::new(&api_key) + } + + pub fn post(&self, path: &str) -> reqwest::RequestBuilder { + let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + self.http_client.post(url) + } + + pub fn completion_model(&self, model: &str) -> CompletionModel { + CompletionModel::new(self.clone(), model) + } + + pub fn agent(&self, model: &str) -> AgentBuilder { + AgentBuilder::new(self.completion_model(model)) + } + + pub fn extractor Deserialize<'a> + Serialize + Send + Sync>( + &self, + model: &str, + ) -> ExtractorBuilder { + ExtractorBuilder::new(self.completion_model(model)) + } +} + +// ================================================================ +// Inception Completion API +// ================================================================ +/// `mercury-coder-small` completion model +pub const MERCURY_CODER_SMALL: &str = "mercury-coder-small"; + +#[derive(Debug, Deserialize)] +pub struct CompletionResponse { + pub id: String, + pub choices: Vec, + pub object: String, + pub created: u64, + pub model: String, + pub usage: Usage, +} + +#[derive(Debug, Deserialize)] +pub struct Choice { + pub index: usize, + pub message: Message, + pub finish_reason: String, +} + +#[derive(Debug, Deserialize)] +struct ApiErrorResponse { + message: String, +} + +impl From for CompletionError { + fn from(err: ApiErrorResponse) -> Self { + CompletionError::ProviderError(err.message) + } +} + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +enum ApiResponse { + Ok(T), + Err(ApiErrorResponse), +} + +#[derive(Clone, Debug, Deserialize)] +pub struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +impl std::fmt::Display for Usage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Prompt tokens: {}\nCompletion tokens: {}\nTotal tokens: {}", + self.prompt_tokens, self.completion_tokens, self.total_tokens + ) + } +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct Message { + pub role: Role, + pub content: String, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum Role { + System, + User, + Assistant, +} + +impl TryFrom for Message { + type Error = MessageError; + + fn try_from(message: message::Message) -> Result { + Ok(match message { + message::Message::User { content } => Message { + role: Role::User, + content: match content.first() { + message::UserContent::Text(message::Text { text }) => text.clone(), + _ => { + return Err(MessageError::ConversionError( + "User message content must be a text message".to_string(), + )) + } + }, + }, + message::Message::Assistant { content } => Message { + role: Role::Assistant, + content: match content.first() { + message::AssistantContent::Text(message::Text { text }) => text.clone(), + _ => { + return Err(MessageError::ConversionError( + "Assistant message content must be a text message".to_string(), + )) + } + }, + }, + }) + } +} + +impl TryFrom for completion::CompletionResponse { + type Error = CompletionError; + + fn try_from(response: CompletionResponse) -> Result { + let choice = response.choices.first().ok_or_else(|| { + CompletionError::ResponseError("Response contained no choices".to_owned()) + })?; + + let content = match &choice.message.role { + Role::Assistant => { + let content = completion::AssistantContent::text(&choice.message.content); + + Ok(content) + } + _ => Err(CompletionError::ResponseError( + "Response did not contain a valid message".into(), + )), + }?; + + let choice = OneOrMany::one(content); + + Ok(completion::CompletionResponse { + choice, + raw_response: response, + }) + } +} + +const MAX_TOKENS: u64 = 8192; + +#[derive(Clone)] +pub struct CompletionModel { + client: Client, + /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1) + pub model: String, +} + +impl CompletionModel { + pub(crate) fn create_completion_request( + &self, + completion_request: CompletionRequest, + ) -> Result { + let mut messages = vec![]; + + if let Some(preamble) = completion_request.preamble.clone() { + messages.push(Message { + role: Role::System, + content: preamble.clone(), + }); + } + + let prompt_message: Message = completion_request + .prompt_with_context() + .try_into() + .map_err(|e: MessageError| CompletionError::RequestError(e.into()))?; + + let chat_history = completion_request + .chat_history + .into_iter() + .map(|message| { + message + .try_into() + .map_err(|e: MessageError| CompletionError::RequestError(e.into())) + }) + .collect::, _>>()?; + + messages.extend(chat_history); + messages.push(prompt_message); + + let max_tokens = completion_request.max_tokens.unwrap_or(MAX_TOKENS); + + let request = json!({ + "model": self.model, + "messages": messages, + // The beta API reference doesn't mention temperature but it doesn't hurt to include it + "temperature": completion_request.temperature, + "max_tokens": max_tokens, + }); + + let request = if let Some(params) = completion_request.additional_params { + json_utils::merge(request, params) + } else { + request + }; + + Ok(request) + } +} + +impl CompletionModel { + pub fn new(client: Client, model: &str) -> Self { + Self { + client, + model: model.to_string(), + } + } +} + +impl completion::CompletionModel for CompletionModel { + type Response = CompletionResponse; + + #[cfg_attr(feature = "worker", worker::send)] + async fn completion( + &self, + completion_request: CompletionRequest, + ) -> Result, CompletionError> { + let request = self.create_completion_request(completion_request)?; + + let response = self + .client + .post("/chat/completions") + .json(&request) + .send() + .await?; + + if response.status().is_success() { + match response.json::>().await? { + ApiResponse::Ok(response) => { + tracing::info!(target: "rig", + "Inception completion token usage: {}", + response.usage + ); + + response.try_into() + } + ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), + } + } else { + Err(CompletionError::ProviderError(response.text().await?)) + } + } +} + +impl StreamingCompletionModel for CompletionModel { + async fn stream( + &self, + completion_request: CompletionRequest, + ) -> Result { + let mut request = self.create_completion_request(completion_request)?; + + merge_inplace(&mut request, json!({"stream": true})); + + let builder = self.client.post("/chat/completions").json(&request); + + send_compatible_streaming_request(builder).await + } +} diff --git a/rig-core/src/providers/inception/client.rs b/rig-core/src/providers/inception/client.rs deleted file mode 100644 index 7871b08..0000000 --- a/rig-core/src/providers/inception/client.rs +++ /dev/null @@ -1,91 +0,0 @@ -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; - -use crate::{ - agent::AgentBuilder, extractor::ExtractorBuilder, - providers::inception::completion::CompletionModel, -}; - -const INCEPTION_API_BASE_URL: &str = "https://api.inceptionlabs.ai/v1"; - -#[derive(Clone)] -pub struct ClientBuilder<'a> { - api_key: &'a str, - base_url: &'a str, -} - -impl<'a> ClientBuilder<'a> { - pub fn new(api_key: &'a str) -> Self { - Self { - api_key, - base_url: INCEPTION_API_BASE_URL, - } - } - - pub fn base_url(mut self, base_url: &'a str) -> Self { - self.base_url = base_url; - self - } - - pub fn build(self) -> Client { - Client::new(self.api_key, self.base_url) - } -} - -#[derive(Clone)] -pub struct Client { - base_url: String, - http_client: reqwest::Client, -} - -impl Client { - pub fn new(api_key: &str, base_url: &str) -> Self { - Self { - base_url: base_url.to_string(), - http_client: reqwest::Client::builder() - .default_headers({ - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert( - "Content-Type", - "application/json" - .parse() - .expect("Content-Type should parse"), - ); - headers.insert( - "Authorization", - format!("Bearer {}", api_key) - .parse() - .expect("Authorization should parse"), - ); - headers - }) - .build() - .expect("Inception reqwest client should build"), - } - } - - pub fn from_env() -> Self { - let api_key = std::env::var("INCEPTION_API_KEY").expect("INCEPTION_API_KEY not set"); - ClientBuilder::new(&api_key).build() - } - - pub fn post(&self, path: &str) -> reqwest::RequestBuilder { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - self.http_client.post(url) - } - - pub fn completion_model(&self, model: &str) -> CompletionModel { - CompletionModel::new(self.clone(), model) - } - - pub fn agent(&self, model: &str) -> AgentBuilder { - AgentBuilder::new(self.completion_model(model)) - } - - pub fn extractor Deserialize<'a> + Serialize + Send + Sync>( - &self, - model: &str, - ) -> ExtractorBuilder { - ExtractorBuilder::new(self.completion_model(model)) - } -} diff --git a/rig-core/src/providers/inception/completion.rs b/rig-core/src/providers/inception/completion.rs deleted file mode 100644 index 18f19dd..0000000 --- a/rig-core/src/providers/inception/completion.rs +++ /dev/null @@ -1,197 +0,0 @@ -use serde::{Deserialize, Serialize}; -use serde_json::json; - -use crate::{ - completion::{self, CompletionError}, - message::{self, MessageError}, - OneOrMany, -}; - -use super::client::Client; - -// ================================================================ -// Inception Completion API -// ================================================================ -/// `mercury-coder-small` completion model -pub const MERCURY_CODER_SMALL: &str = "mercury-coder-small"; - -#[derive(Debug, Deserialize)] -pub struct CompletionResponse { - pub id: String, - pub choices: Vec, - pub object: String, - pub created: u64, - pub model: String, - pub usage: Usage, -} - -#[derive(Debug, Deserialize)] -pub struct Choice { - pub index: usize, - pub message: Message, - pub finish_reason: String, -} - -impl From for completion::AssistantContent { - fn from(choice: Choice) -> Self { - completion::AssistantContent::from(&choice) - } -} - -impl From<&Choice> for completion::AssistantContent { - fn from(choice: &Choice) -> Self { - completion::AssistantContent::Text(completion::message::Text { - text: choice.message.content.clone(), - }) - } -} - -#[derive(Debug, Deserialize, Serialize)] -pub struct Message { - pub role: Role, - pub content: String, -} - -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] -#[serde(rename_all = "lowercase")] -pub enum Role { - User, - Assistant, -} - -#[derive(Debug, Deserialize)] -pub struct Usage { - pub prompt_tokens: u32, - pub completion_tokens: u32, - pub total_tokens: u32, -} - -impl std::fmt::Display for Usage { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "Prompt tokens: {}\nCompletion tokens: {}\nTotal tokens: {}", - self.prompt_tokens, self.completion_tokens, self.total_tokens - ) - } -} - -impl TryFrom for Message { - type Error = MessageError; - - fn try_from(message: message::Message) -> Result { - Ok(match message { - message::Message::User { content } => Message { - role: Role::User, - content: match content.first() { - message::UserContent::Text(message::Text { text }) => text.clone(), - _ => { - return Err(MessageError::ConversionError( - "User message content must be a text message".to_string(), - )) - } - }, - }, - message::Message::Assistant { content } => Message { - role: Role::Assistant, - content: match content.first() { - message::AssistantContent::Text(message::Text { text }) => text.clone(), - _ => { - return Err(MessageError::ConversionError( - "Assistant message content must be a text message".to_string(), - )) - } - }, - }, - }) - } -} - -impl TryFrom for completion::CompletionResponse { - type Error = CompletionError; - - fn try_from(response: CompletionResponse) -> Result { - let content = response.choices.iter().map(Into::into).collect::>(); - - let choice = OneOrMany::many(content).map_err(|_| { - CompletionError::ResponseError( - "Response contained no message or tool call (empty)".to_owned(), - ) - })?; - - Ok(completion::CompletionResponse { - choice, - raw_response: response, - }) - } -} - -const MAX_TOKENS: u64 = 8192; - -#[derive(Clone)] -pub struct CompletionModel { - pub(crate) client: Client, - pub model: String, -} - -impl CompletionModel { - pub fn new(client: Client, model: &str) -> Self { - Self { - client, - model: model.to_string(), - } - } -} - -impl completion::CompletionModel for CompletionModel { - type Response = CompletionResponse; - - #[cfg_attr(feature = "worker", worker::send)] - async fn completion( - &self, - completion_request: completion::CompletionRequest, - ) -> Result, CompletionError> { - let max_tokens = completion_request.max_tokens.unwrap_or(MAX_TOKENS); - - let prompt_message: Message = completion_request - .prompt_with_context() - .try_into() - .map_err(|e: MessageError| CompletionError::RequestError(e.into()))?; - - let mut messages = completion_request - .chat_history - .into_iter() - .map(|message| { - message - .try_into() - .map_err(|e: MessageError| CompletionError::RequestError(e.into())) - }) - .collect::, _>>()?; - - messages.push(prompt_message); - - let request = json!({ - "model": self.model, - "messages": messages, - "max_tokens": max_tokens, - }); - - let response = self - .client - .post("/chat/completions") - .json(&request) - .send() - .await?; - - if response.status().is_success() { - let response = response.json::().await?; - tracing::info!(target: "rig", - "Inception completion token usage: {}", - response.usage - ); - Ok(response.try_into()?) - } else { - Err(CompletionError::ProviderError(response.text().await?)) - } - } -} diff --git a/rig-core/src/providers/inception/mod.rs b/rig-core/src/providers/inception/mod.rs deleted file mode 100644 index 08fec85..0000000 --- a/rig-core/src/providers/inception/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -pub mod client; -pub mod completion; -pub mod streaming; - -pub use client::{Client, ClientBuilder}; -pub use completion::MERCURY_CODER_SMALL; diff --git a/rig-core/src/providers/inception/streaming.rs b/rig-core/src/providers/inception/streaming.rs deleted file mode 100644 index ed3eb0a..0000000 --- a/rig-core/src/providers/inception/streaming.rs +++ /dev/null @@ -1,121 +0,0 @@ -use async_stream::stream; -use futures::StreamExt; -use serde::Deserialize; -use serde_json::json; - -use super::completion::{CompletionModel, Message}; -use crate::completion::{CompletionError, CompletionRequest}; -use crate::json_utils::merge_inplace; -use crate::message::MessageError; -use crate::providers::anthropic::decoders::sse::from_response as sse_from_response; -use crate::streaming::{self, StreamingCompletionModel, StreamingResult}; - -#[derive(Debug, Deserialize)] -pub struct StreamingResponse { - pub id: String, - pub object: String, - pub created: u64, - pub model: String, - pub choices: Vec, -} - -#[derive(Debug, Deserialize)] -pub struct StreamingChoice { - pub index: usize, - pub delta: Delta, - pub finish_reason: Option, -} - -#[derive(Debug, Deserialize)] -pub struct Delta { - pub content: Option, - pub role: Option, -} - -impl StreamingCompletionModel for CompletionModel { - async fn stream( - &self, - completion_request: CompletionRequest, - ) -> Result { - let prompt_message: Message = completion_request - .prompt_with_context() - .try_into() - .map_err(|e: MessageError| CompletionError::RequestError(e.into()))?; - - let mut messages = completion_request - .chat_history - .into_iter() - .map(|message| { - message - .try_into() - .map_err(|e: MessageError| CompletionError::RequestError(e.into())) - }) - .collect::, _>>()?; - - messages.push(prompt_message); - - let mut request = json!({ - "model": self.model, - "messages": messages, - "max_tokens": completion_request.max_tokens.unwrap_or(8192), - "stream": true, - }); - - if let Some(temperature) = completion_request.temperature { - merge_inplace(&mut request, json!({ "temperature": temperature })); - } - - if let Some(ref params) = completion_request.additional_params { - merge_inplace(&mut request, params.clone()) - } - - let response = self - .client - .post("chat/completions") - .json(&request) - .send() - .await?; - - if !response.status().is_success() { - return Err(CompletionError::ProviderError(response.text().await?)); - } - - // Use our SSE decoder to directly handle Server-Sent Events format - let sse_stream = sse_from_response(response); - - Ok(Box::pin(stream! { - let mut sse_stream = Box::pin(sse_stream); - - while let Some(sse_result) = sse_stream.next().await { - match sse_result { - Ok(sse) => { - // Parse the SSE data as a StreamingResponse - match serde_json::from_str::(&sse.data) { - Ok(response) => { - if let Some(choice) = response.choices.first() { - if let Some(content) = &choice.delta.content { - yield Ok(streaming::StreamingChoice::Message(content.clone())); - } - if choice.finish_reason.as_deref() == Some("stop") { - break; - } - } - }, - Err(e) => { - if !sse.data.trim().is_empty() { - yield Err(CompletionError::ResponseError( - format!("Failed to parse JSON: {} (Data: {})", e, sse.data) - )); - } - } - } - }, - Err(e) => { - yield Err(CompletionError::ResponseError(format!("SSE Error: {}", e))); - break; - } - } - } - })) - } -}