Simplify the inception provider module to use the OpenAI-compatible streaming API

This commit is contained in:
Collin Brittain 2025-04-11 14:03:52 -05:00
parent fc412bc544
commit 01c93176cf
7 changed files with 328 additions and 421 deletions

View File

@ -2,21 +2,18 @@ use std::env;
use rig::{ use rig::{
completion::Prompt, completion::Prompt,
providers::inception::{ClientBuilder, MERCURY_CODER_SMALL}, providers::inception::{Client, MERCURY_CODER_SMALL},
}; };
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), anyhow::Error> { async fn main() -> Result<(), anyhow::Error> {
// Create Inception Labs client // Create Inception Labs client
let client = let client = Client::new(&env::var("INCEPTION_API_KEY").expect("INCEPTION_API_KEY not set"));
ClientBuilder::new(&env::var("INCEPTION_API_KEY").expect("INCEPTION_API_KEY not set"))
.build();
// Create agent with a single context prompt // Create agent with a single context prompt
let agent = client let agent = client
.agent(MERCURY_CODER_SMALL) .agent(MERCURY_CODER_SMALL)
.preamble("You are a helpful AI assistant.") .preamble("You are a helpful AI assistant.")
.temperature(0.0)
.build(); .build();
// Prompt the agent and print the response // Prompt the agent and print the response

View File

@ -1,5 +1,5 @@
use rig::{ use rig::{
providers::inception::{self, completion::MERCURY_CODER_SMALL}, providers::inception::{self, MERCURY_CODER_SMALL},
streaming::{stream_to_stdout, StreamingPrompt}, streaming::{stream_to_stdout, StreamingPrompt},
}; };

View File

@ -0,0 +1,325 @@
use super::openai::send_compatible_streaming_request;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use crate::{
agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest},
extractor::ExtractorBuilder,
json_utils::{self, merge_inplace},
message::{self, MessageError},
streaming::{StreamingCompletionModel, StreamingResult},
OneOrMany,
};
/// Default base URL of the Inception Labs REST API (OpenAI-compatible).
const INCEPTION_API_BASE_URL: &str = "https://api.inceptionlabs.ai/v1";
/// HTTP client for the Inception Labs API.
///
/// Holds the API base URL and a pre-configured `reqwest` client whose default
/// headers carry the `Authorization` bearer token and JSON content type.
#[derive(Clone)]
pub struct Client {
base_url: String,
http_client: reqwest::Client,
}
impl Client {
    /// Creates a new Inception Labs client with the given API key.
    ///
    /// # Panics
    /// Panics if the API key cannot be encoded as a header value or if the
    /// underlying `reqwest` client fails to build.
    pub fn new(api_key: &str) -> Self {
        Self {
            base_url: INCEPTION_API_BASE_URL.to_string(),
            http_client: reqwest::Client::builder()
                .default_headers({
                    let mut headers = reqwest::header::HeaderMap::new();
                    headers.insert(
                        "Content-Type",
                        "application/json"
                            .parse()
                            .expect("Content-Type should parse"),
                    );
                    headers.insert(
                        "Authorization",
                        format!("Bearer {}", api_key)
                            .parse()
                            .expect("Authorization should parse"),
                    );
                    headers
                })
                .build()
                .expect("Inception reqwest client should build"),
        }
    }

    /// Creates a client from the `INCEPTION_API_KEY` environment variable.
    ///
    /// # Panics
    /// Panics if the variable is not set.
    pub fn from_env() -> Self {
        let api_key = std::env::var("INCEPTION_API_KEY").expect("INCEPTION_API_KEY not set");
        Client::new(&api_key)
    }

    /// Builds a POST request for `path` relative to the base URL.
    ///
    /// Joins the two parts with exactly one `/`, whether or not the base URL
    /// ends with one or `path` starts with one. The previous
    /// `.replace("//", "/")` approach also rewrote the `//` inside
    /// `https://`, yielding a malformed `https:/...` string that only worked
    /// if the URL parser normalized the missing slash back in.
    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
        let url = format!(
            "{}/{}",
            self.base_url.trim_end_matches('/'),
            path.trim_start_matches('/')
        );
        self.http_client.post(url)
    }

    /// Creates a completion model handle for the named `model`.
    pub fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }

    /// Creates an agent builder backed by the named `model`.
    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
        AgentBuilder::new(self.completion_model(model))
    }

    /// Creates an extractor builder producing structured output of type `T`.
    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
        &self,
        model: &str,
    ) -> ExtractorBuilder<T, CompletionModel> {
        ExtractorBuilder::new(self.completion_model(model))
    }
}
// ================================================================
// Inception Completion API
// ================================================================
/// `mercury-coder-small` completion model identifier, as accepted by the
/// `model` field of the chat completions endpoint.
pub const MERCURY_CODER_SMALL: &str = "mercury-coder-small";
/// Top-level body returned by the Inception `/chat/completions` endpoint
/// (non-streaming), mirroring the OpenAI chat completion response shape.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
pub id: String,
pub choices: Vec<Choice>,
pub object: String,
pub created: u64,
pub model: String,
pub usage: Usage,
}
/// One candidate completion within a [`CompletionResponse`].
#[derive(Debug, Deserialize)]
pub struct Choice {
pub index: usize,
pub message: Message,
pub finish_reason: String,
}
/// Error payload returned by the API on failure; only the message is kept.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
message: String,
}
/// Maps an API error payload onto rig's provider-error variant.
impl From<ApiErrorResponse> for CompletionError {
fn from(err: ApiErrorResponse) -> Self {
CompletionError::ProviderError(err.message)
}
}
/// Either a successful payload or an API error body.
///
/// `#[serde(untagged)]` means deserialization tries `Ok(T)` first and falls
/// back to the error shape, so the two variants must not overlap structurally.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
Ok(T),
Err(ApiErrorResponse),
}
/// Token accounting reported by the API for a single request.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
impl std::fmt::Display for Usage {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Prompt tokens: {}\nCompletion tokens: {}\nTotal tokens: {}",
self.prompt_tokens, self.completion_tokens, self.total_tokens
)
}
}
/// Wire-format chat message: a role plus plain-text content.
#[derive(Debug, Deserialize, Serialize)]
pub struct Message {
pub role: Role,
pub content: String,
}
/// Chat role, serialized lowercase (`"system"`, `"user"`, `"assistant"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
System,
User,
Assistant,
}
/// Converts rig's internal message type into the provider wire format.
///
/// Only the FIRST content item is used and it must be text; any other
/// content kind (images, tool calls/results, ...) fails with a
/// `ConversionError`. Additional content items beyond the first are dropped.
impl TryFrom<message::Message> for Message {
type Error = MessageError;
fn try_from(message: message::Message) -> Result<Self, Self::Error> {
Ok(match message {
message::Message::User { content } => Message {
role: Role::User,
// first() on OneOrMany always yields an element
content: match content.first() {
message::UserContent::Text(message::Text { text }) => text.clone(),
_ => {
return Err(MessageError::ConversionError(
"User message content must be a text message".to_string(),
))
}
},
},
message::Message::Assistant { content } => Message {
role: Role::Assistant,
content: match content.first() {
message::AssistantContent::Text(message::Text { text }) => text.clone(),
_ => {
return Err(MessageError::ConversionError(
"Assistant message content must be a text message".to_string(),
))
}
},
},
})
}
}
/// Lifts a raw provider response into rig's generic completion response.
///
/// Only the first choice is inspected; it must be an assistant message.
/// The full raw response is preserved in `raw_response`.
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
type Error = CompletionError;
fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
let choice = response.choices.first().ok_or_else(|| {
CompletionError::ResponseError("Response contained no choices".to_owned())
})?;
let content = match &choice.message.role {
Role::Assistant => {
let content = completion::AssistantContent::text(&choice.message.content);
Ok(content)
}
// System/User roles are never valid in a model response
_ => Err(CompletionError::ResponseError(
"Response did not contain a valid message".into(),
)),
}?;
let choice = OneOrMany::one(content);
Ok(completion::CompletionResponse {
choice,
raw_response: response,
})
}
}
/// Default `max_tokens` applied when the request does not specify one.
const MAX_TOKENS: u64 = 8192;
/// A completion model bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
client: Client,
/// Name of the model (e.g.: `mercury-coder-small`)
pub model: String,
}
impl CompletionModel {
pub(crate) fn create_completion_request(
&self,
completion_request: CompletionRequest,
) -> Result<Value, CompletionError> {
let mut messages = vec![];
if let Some(preamble) = completion_request.preamble.clone() {
messages.push(Message {
role: Role::System,
content: preamble.clone(),
});
}
let prompt_message: Message = completion_request
.prompt_with_context()
.try_into()
.map_err(|e: MessageError| CompletionError::RequestError(e.into()))?;
let chat_history = completion_request
.chat_history
.into_iter()
.map(|message| {
message
.try_into()
.map_err(|e: MessageError| CompletionError::RequestError(e.into()))
})
.collect::<Result<Vec<Message>, _>>()?;
messages.extend(chat_history);
messages.push(prompt_message);
let max_tokens = completion_request.max_tokens.unwrap_or(MAX_TOKENS);
let request = json!({
"model": self.model,
"messages": messages,
// The beta API reference doesn't mention temperature but it doesn't hurt to include it
"temperature": completion_request.temperature,
"max_tokens": max_tokens,
});
let request = if let Some(params) = completion_request.additional_params {
json_utils::merge(request, params)
} else {
request
};
Ok(request)
}
}
impl CompletionModel {
    /// Creates a model handle that issues requests through `client` for the
    /// named `model`.
    pub fn new(client: Client, model: &str) -> Self {
        let model = model.to_owned();
        Self { client, model }
    }
}
impl completion::CompletionModel for CompletionModel {
type Response = CompletionResponse;
/// Sends a non-streaming chat completion request and converts the reply
/// into rig's generic response type.
///
/// # Errors
/// Returns `ProviderError` on non-2xx responses or API-level errors, and
/// `ResponseError` if the body cannot be converted.
#[cfg_attr(feature = "worker", worker::send)]
async fn completion(
&self,
completion_request: CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
let request = self.create_completion_request(completion_request)?;
let response = self
.client
.post("/chat/completions")
.json(&request)
.send()
.await?;
if response.status().is_success() {
// A 2xx body can still be an API error envelope; the untagged
// ApiResponse distinguishes the two shapes.
match response.json::<ApiResponse<CompletionResponse>>().await? {
ApiResponse::Ok(response) => {
tracing::info!(target: "rig",
"Inception completion token usage: {}",
response.usage
);
response.try_into()
}
ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
}
} else {
Err(CompletionError::ProviderError(response.text().await?))
}
}
}
impl StreamingCompletionModel for CompletionModel {
/// Sends a streaming chat completion request.
///
/// Reuses the shared request builder, flips on `"stream": true`, and
/// delegates SSE parsing to the OpenAI-compatible streaming helper.
async fn stream(
&self,
completion_request: CompletionRequest,
) -> Result<StreamingResult, CompletionError> {
let mut request = self.create_completion_request(completion_request)?;
merge_inplace(&mut request, json!({"stream": true}));
let builder = self.client.post("/chat/completions").json(&request);
send_compatible_streaming_request(builder).await
}
}

View File

@ -1,91 +0,0 @@
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use crate::{
agent::AgentBuilder, extractor::ExtractorBuilder,
providers::inception::completion::CompletionModel,
};
/// Default base URL of the Inception Labs REST API.
const INCEPTION_API_BASE_URL: &str = "https://api.inceptionlabs.ai/v1";
/// Borrowing builder for [`Client`]; allows overriding the base URL.
#[derive(Clone)]
pub struct ClientBuilder<'a> {
api_key: &'a str,
base_url: &'a str,
}
impl<'a> ClientBuilder<'a> {
/// Starts a builder with the default base URL.
pub fn new(api_key: &'a str) -> Self {
Self {
api_key,
base_url: INCEPTION_API_BASE_URL,
}
}
/// Overrides the API base URL (e.g. for testing or proxies).
pub fn base_url(mut self, base_url: &'a str) -> Self {
self.base_url = base_url;
self
}
/// Finalizes the builder into an owned [`Client`].
pub fn build(self) -> Client {
Client::new(self.api_key, self.base_url)
}
}
/// HTTP client for the Inception Labs API (legacy, builder-based).
#[derive(Clone)]
pub struct Client {
base_url: String,
http_client: reqwest::Client,
}
impl Client {
/// Creates a client with an explicit base URL; credentials are baked
/// into the default `Authorization` header.
pub fn new(api_key: &str, base_url: &str) -> Self {
Self {
base_url: base_url.to_string(),
http_client: reqwest::Client::builder()
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"Content-Type",
"application/json"
.parse()
.expect("Content-Type should parse"),
);
headers.insert(
"Authorization",
format!("Bearer {}", api_key)
.parse()
.expect("Authorization should parse"),
);
headers
})
.build()
.expect("Inception reqwest client should build"),
}
}
/// Creates a client from the `INCEPTION_API_KEY` environment variable.
/// Panics if the variable is not set.
pub fn from_env() -> Self {
let api_key = std::env::var("INCEPTION_API_KEY").expect("INCEPTION_API_KEY not set");
ClientBuilder::new(&api_key).build()
}
/// Builds a POST request for `path` under the base URL.
/// NOTE(review): `.replace("//", "/")` also rewrites the `//` inside
/// `https://`, producing `https:/...` — works only if the URL parser
/// normalizes the missing slash; verify before reuse.
pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
self.http_client.post(url)
}
/// Creates a completion model handle for the named `model`.
pub fn completion_model(&self, model: &str) -> CompletionModel {
CompletionModel::new(self.clone(), model)
}
/// Creates an agent builder backed by the named `model`.
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
AgentBuilder::new(self.completion_model(model))
}
/// Creates an extractor builder producing structured output of type `T`.
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
&self,
model: &str,
) -> ExtractorBuilder<T, CompletionModel> {
ExtractorBuilder::new(self.completion_model(model))
}
}

View File

@ -1,197 +0,0 @@
use serde::{Deserialize, Serialize};
use serde_json::json;
use crate::{
completion::{self, CompletionError},
message::{self, MessageError},
OneOrMany,
};
use super::client::Client;
// ================================================================
// Inception Completion API
// ================================================================
/// `mercury-coder-small` completion model
pub const MERCURY_CODER_SMALL: &str = "mercury-coder-small";
/// Raw response body of the non-streaming chat completions endpoint.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
pub id: String,
pub choices: Vec<Choice>,
pub object: String,
pub created: u64,
pub model: String,
pub usage: Usage,
}
/// One candidate completion within a [`CompletionResponse`].
#[derive(Debug, Deserialize)]
pub struct Choice {
pub index: usize,
pub message: Message,
pub finish_reason: String,
}
/// Owned conversion delegates to the borrowing impl below.
impl From<Choice> for completion::AssistantContent {
fn from(choice: Choice) -> Self {
completion::AssistantContent::from(&choice)
}
}
/// Extracts the choice's message text as assistant content.
impl From<&Choice> for completion::AssistantContent {
fn from(choice: &Choice) -> Self {
completion::AssistantContent::Text(completion::message::Text {
text: choice.message.content.clone(),
})
}
}
/// Wire-format chat message: role plus plain-text content.
#[derive(Debug, Deserialize, Serialize)]
pub struct Message {
pub role: Role,
pub content: String,
}
/// Chat role, serialized lowercase. Note: no `System` variant here —
/// the legacy module never sent a system preamble.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
}
/// Token accounting reported by the API.
#[derive(Debug, Deserialize)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
impl std::fmt::Display for Usage {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Prompt tokens: {}\nCompletion tokens: {}\nTotal tokens: {}",
self.prompt_tokens, self.completion_tokens, self.total_tokens
)
}
}
/// Converts rig's internal message type to the wire format; only the first
/// content item is used and it must be text.
impl TryFrom<message::Message> for Message {
type Error = MessageError;
fn try_from(message: message::Message) -> Result<Self, Self::Error> {
Ok(match message {
message::Message::User { content } => Message {
role: Role::User,
content: match content.first() {
message::UserContent::Text(message::Text { text }) => text.clone(),
_ => {
return Err(MessageError::ConversionError(
"User message content must be a text message".to_string(),
))
}
},
},
message::Message::Assistant { content } => Message {
role: Role::Assistant,
content: match content.first() {
message::AssistantContent::Text(message::Text { text }) => text.clone(),
_ => {
return Err(MessageError::ConversionError(
"Assistant message content must be a text message".to_string(),
))
}
},
},
})
}
}
/// Lifts a raw response into rig's generic completion response, keeping
/// ALL choices (unlike the replacement module, which keeps only the first).
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
type Error = CompletionError;
fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
let content = response.choices.iter().map(Into::into).collect::<Vec<_>>();
let choice = OneOrMany::many(content).map_err(|_| {
CompletionError::ResponseError(
"Response contained no message or tool call (empty)".to_owned(),
)
})?;
Ok(completion::CompletionResponse {
choice,
raw_response: response,
})
}
}
/// Default `max_tokens` when the request does not specify one.
const MAX_TOKENS: u64 = 8192;
/// A completion model bound to a legacy `Client`.
#[derive(Clone)]
pub struct CompletionModel {
pub(crate) client: Client,
pub model: String,
}
impl CompletionModel {
/// Creates a model handle for `model` using `client`.
pub fn new(client: Client, model: &str) -> Self {
Self {
client,
model: model.to_string(),
}
}
}
impl completion::CompletionModel for CompletionModel {
type Response = CompletionResponse;
/// Sends a non-streaming chat completion request.
/// NOTE(review): the preamble is never sent (no system role) and
/// temperature/additional_params are ignored here.
#[cfg_attr(feature = "worker", worker::send)]
async fn completion(
&self,
completion_request: completion::CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
let max_tokens = completion_request.max_tokens.unwrap_or(MAX_TOKENS);
let prompt_message: Message = completion_request
.prompt_with_context()
.try_into()
.map_err(|e: MessageError| CompletionError::RequestError(e.into()))?;
let mut messages = completion_request
.chat_history
.into_iter()
.map(|message| {
message
.try_into()
.map_err(|e: MessageError| CompletionError::RequestError(e.into()))
})
.collect::<Result<Vec<Message>, _>>()?;
messages.push(prompt_message);
let request = json!({
"model": self.model,
"messages": messages,
"max_tokens": max_tokens,
});
let response = self
.client
.post("/chat/completions")
.json(&request)
.send()
.await?;
if response.status().is_success() {
let response = response.json::<CompletionResponse>().await?;
tracing::info!(target: "rig",
"Inception completion token usage: {}",
response.usage
);
Ok(response.try_into()?)
} else {
Err(CompletionError::ProviderError(response.text().await?))
}
}
}

View File

@ -1,6 +0,0 @@
// Legacy module layout: client/completion/streaming split into submodules,
// with the public surface re-exported here.
pub mod client;
pub mod completion;
pub mod streaming;
pub use client::{Client, ClientBuilder};
pub use completion::MERCURY_CODER_SMALL;

View File

@ -1,121 +0,0 @@
use async_stream::stream;
use futures::StreamExt;
use serde::Deserialize;
use serde_json::json;
use super::completion::{CompletionModel, Message};
use crate::completion::{CompletionError, CompletionRequest};
use crate::json_utils::merge_inplace;
use crate::message::MessageError;
use crate::providers::anthropic::decoders::sse::from_response as sse_from_response;
use crate::streaming::{self, StreamingCompletionModel, StreamingResult};
/// One SSE chunk of a streamed chat completion (OpenAI-style shape).
#[derive(Debug, Deserialize)]
pub struct StreamingResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<StreamingChoice>,
}
/// One candidate within a streaming chunk.
#[derive(Debug, Deserialize)]
pub struct StreamingChoice {
pub index: usize,
pub delta: Delta,
pub finish_reason: Option<String>,
}
/// Incremental message fragment carried by a streaming choice.
#[derive(Debug, Deserialize)]
pub struct Delta {
pub content: Option<String>,
pub role: Option<String>,
}
impl StreamingCompletionModel for CompletionModel {
/// Sends a streaming chat completion request and hand-parses the SSE
/// stream (superseded by the shared OpenAI-compatible helper).
async fn stream(
&self,
completion_request: CompletionRequest,
) -> Result<StreamingResult, CompletionError> {
let prompt_message: Message = completion_request
.prompt_with_context()
.try_into()
.map_err(|e: MessageError| CompletionError::RequestError(e.into()))?;
let mut messages = completion_request
.chat_history
.into_iter()
.map(|message| {
message
.try_into()
.map_err(|e: MessageError| CompletionError::RequestError(e.into()))
})
.collect::<Result<Vec<Message>, _>>()?;
messages.push(prompt_message);
let mut request = json!({
"model": self.model,
"messages": messages,
"max_tokens": completion_request.max_tokens.unwrap_or(8192),
"stream": true,
});
// Temperature/extra params are merged only when present, so no null
// fields are serialized.
if let Some(temperature) = completion_request.temperature {
merge_inplace(&mut request, json!({ "temperature": temperature }));
}
if let Some(ref params) = completion_request.additional_params {
merge_inplace(&mut request, params.clone())
}
let response = self
.client
.post("chat/completions")
.json(&request)
.send()
.await?;
if !response.status().is_success() {
return Err(CompletionError::ProviderError(response.text().await?));
}
// Use our SSE decoder to directly handle Server-Sent Events format
let sse_stream = sse_from_response(response);
Ok(Box::pin(stream! {
let mut sse_stream = Box::pin(sse_stream);
while let Some(sse_result) = sse_stream.next().await {
match sse_result {
Ok(sse) => {
// Parse the SSE data as a StreamingResponse
match serde_json::from_str::<StreamingResponse>(&sse.data) {
Ok(response) => {
if let Some(choice) = response.choices.first() {
if let Some(content) = &choice.delta.content {
yield Ok(streaming::StreamingChoice::Message(content.clone()));
}
// "stop" marks the final chunk; end the stream.
if choice.finish_reason.as_deref() == Some("stop") {
break;
}
}
},
Err(e) => {
// Ignore empty keep-alive frames; surface real parse failures.
if !sse.data.trim().is_empty() {
yield Err(CompletionError::ResponseError(
format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
));
}
}
}
},
Err(e) => {
yield Err(CompletionError::ResponseError(format!("SSE Error: {}", e)));
break;
}
}
}
}))
}
}