Initial Inception provider support with working completions

This commit is contained in:
Collin Brittain 2025-04-11 11:18:12 -05:00
parent c7d4851e32
commit 391e6c87d3
4 changed files with 292 additions and 0 deletions

View File

@ -0,0 +1,91 @@
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use crate::{
agent::AgentBuilder, extractor::ExtractorBuilder,
providers::inception::completion::CompletionModel,
};
/// Base URL that `ClientBuilder::new` starts with; callers can replace it via
/// `ClientBuilder::base_url`.
const INCEPTION_API_BASE_URL: &str = "https://api.inceptionlabs.ai/v1";
/// Builder for an Inception [`Client`].
///
/// Both fields are borrowed for `'a`; `build` forwards them to `Client::new`,
/// which copies what it needs, so the resulting `Client` does not borrow from
/// the builder.
#[derive(Clone)]
pub struct ClientBuilder<'a> {
/// API key forwarded to `Client::new`.
api_key: &'a str,
/// Base URL forwarded to `Client::new`; set to `INCEPTION_API_BASE_URL` by `new`.
base_url: &'a str,
}
impl<'a> ClientBuilder<'a> {
    /// Starts a builder for the given API key, pointing at the default
    /// Inception base URL (`INCEPTION_API_BASE_URL`).
    pub fn new(api_key: &'a str) -> Self {
        let base_url = INCEPTION_API_BASE_URL;
        Self { api_key, base_url }
    }

    /// Returns the builder with its base URL replaced by `base_url`.
    pub fn base_url(self, base_url: &'a str) -> Self {
        Self { base_url, ..self }
    }

    /// Consumes the builder and constructs the [`Client`] via `Client::new`.
    pub fn build(self) -> Client {
        let Self { api_key, base_url } = self;
        Client::new(api_key, base_url)
    }
}
/// Handle for talking to the Inception HTTP API.
///
/// Cloned into every `CompletionModel` created by `completion_model`.
#[derive(Clone)]
pub struct Client {
/// Base URL that request paths are appended to in `post`.
base_url: String,
/// `reqwest` client that `Client::new` configures with default
/// `Content-Type` and `Authorization` headers.
http_client: reqwest::Client,
}
impl Client {
    /// Creates a client that sends requests to `base_url` and authenticates
    /// with `api_key` as a bearer token.
    ///
    /// Every request made through the client carries the default headers
    /// `Content-Type: application/json` and `Authorization: Bearer <api_key>`.
    ///
    /// # Panics
    ///
    /// Panics if `api_key` cannot be encoded as an HTTP header value, or if
    /// the underlying `reqwest` client fails to build.
    pub fn new(api_key: &str, base_url: &str) -> Self {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            "Content-Type",
            "application/json"
                .parse()
                .expect("Content-Type should parse"),
        );
        headers.insert(
            "Authorization",
            format!("Bearer {}", api_key)
                .parse()
                .expect("Authorization should parse"),
        );

        Self {
            base_url: base_url.to_string(),
            http_client: reqwest::Client::builder()
                .default_headers(headers)
                .build()
                .expect("Inception reqwest client should build"),
        }
    }

    /// Creates a client from the `INCEPTION_API_KEY` environment variable,
    /// using the default base URL.
    ///
    /// # Panics
    ///
    /// Panics if `INCEPTION_API_KEY` is not set (or is not valid Unicode).
    pub fn from_env() -> Self {
        let api_key = std::env::var("INCEPTION_API_KEY").expect("INCEPTION_API_KEY not set");
        ClientBuilder::new(&api_key).build()
    }

    /// Starts a `POST` request to `path`, resolved relative to the base URL.
    ///
    /// `path` may be given with or without a leading `/`, and the base URL
    /// with or without a trailing `/`; exactly one separator is placed
    /// between them.
    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
        // Normalise only the join point. A blanket `replace("//", "/")` over
        // the whole URL would also collapse the `//` that follows the scheme
        // (`https://` -> `https:/`) and any `//` inside the path or query.
        let url = format!(
            "{}/{}",
            self.base_url.trim_end_matches('/'),
            path.trim_start_matches('/')
        );
        self.http_client.post(url)
    }

    /// Creates a completion model handle for the model named `model`.
    ///
    /// The returned model owns a clone of this client.
    pub fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }

    /// Creates an agent builder backed by the completion model named `model`.
    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
        AgentBuilder::new(self.completion_model(model))
    }

    /// Creates an extractor builder for `T`, backed by the completion model
    /// named `model`.
    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
        &self,
        model: &str,
    ) -> ExtractorBuilder<T, CompletionModel> {
        ExtractorBuilder::new(self.completion_model(model))
    }
}

View File

@ -0,0 +1,197 @@
use serde::{Deserialize, Serialize};
use serde_json::json;
use crate::{
completion::{self, CompletionError},
message::{self, MessageError},
OneOrMany,
};
use super::client::Client;
// ================================================================
// Inception Completion API
// ================================================================
/// Model name `mercury-coder-small`, usable as the `model` argument of
/// `CompletionModel::new`.
pub const MERCURY_CODER_SMALL: &str = "mercury-coder-small";
/// Body that `CompletionModel::completion` deserializes from a successful
/// response.
///
/// NOTE(review): the field set presumably mirrors the provider's JSON schema
/// one-to-one — confirm against the Inception API documentation.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
pub id: String,
/// Each element is converted to an `AssistantContent` by the
/// `TryFrom<CompletionResponse>` impl in this file.
pub choices: Vec<Choice>,
pub object: String,
pub created: u64,
pub model: String,
/// Logged at info level after a successful completion.
pub usage: Usage,
}
/// One element of `CompletionResponse::choices`.
///
/// When converted to `completion::AssistantContent`, only `message.content`
/// is read; `index` and `finish_reason` are not used by this file.
#[derive(Debug, Deserialize)]
pub struct Choice {
pub index: usize,
pub message: Message,
pub finish_reason: String,
}
/// Owned-value convenience wrapper around the `From<&Choice>` conversion.
impl From<Choice> for completion::AssistantContent {
    fn from(choice: Choice) -> Self {
        let borrowed: &Choice = &choice;
        borrowed.into()
    }
}
/// Turns a choice into text-only assistant content by copying the text of its
/// message.
impl From<&Choice> for completion::AssistantContent {
    fn from(choice: &Choice) -> Self {
        let text = choice.message.content.clone();
        let body = completion::message::Text { text };
        completion::AssistantContent::Text(body)
    }
}
/// Text-only chat message in the provider's wire format.
///
/// Serialized into the `messages` array of the completion request and
/// deserialized as the `message` field of each `Choice`.
#[derive(Debug, Deserialize, Serialize)]
pub struct Message {
pub role: Role,
pub content: String,
}
/// Author of a [`Message`].
///
/// `rename_all = "lowercase"` makes the variants (de)serialize as `"user"`
/// and `"assistant"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
}
/// Token counts deserialized from the `usage` field of a `CompletionResponse`.
///
/// Its `Display` impl renders the three counts on separate lines.
#[derive(Debug, Deserialize)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
impl std::fmt::Display for Usage {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Prompt tokens: {}\nCompletion tokens: {}\nTotal tokens: {}",
self.prompt_tokens, self.completion_tokens, self.total_tokens
)
}
}
/// Converts a generic chat message into the provider's text-only [`Message`].
///
/// Every content part of the source message is converted: text parts are
/// joined with a single `\n`, in order. A message with exactly one text part
/// therefore converts to that text unchanged.
///
/// # Errors
///
/// Returns `MessageError::ConversionError` if any content part is not text,
/// because the wire format here carries a plain string only.
impl TryFrom<message::Message> for Message {
    type Error = MessageError;

    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        // Walk all parts rather than only the first one, so that additional
        // parts are never dropped silently.
        let (role, parts) = match &message {
            message::Message::User { content } => (
                Role::User,
                content
                    .iter()
                    .map(|part| match part {
                        message::UserContent::Text(message::Text { text }) => Ok(text.as_str()),
                        _ => Err(MessageError::ConversionError(
                            "User message content must be a text message".to_string(),
                        )),
                    })
                    .collect::<Result<Vec<_>, _>>()?,
            ),
            message::Message::Assistant { content } => (
                Role::Assistant,
                content
                    .iter()
                    .map(|part| match part {
                        message::AssistantContent::Text(message::Text { text }) => {
                            Ok(text.as_str())
                        }
                        _ => Err(MessageError::ConversionError(
                            "Assistant message content must be a text message".to_string(),
                        )),
                    })
                    .collect::<Result<Vec<_>, _>>()?,
            ),
        };

        Ok(Message {
            role,
            content: parts.join("\n"),
        })
    }
}
/// Wraps a raw provider response in the crate-level completion response.
///
/// Each entry of `choices` becomes one assistant content item; the raw
/// response is kept alongside as `raw_response`.
///
/// # Errors
///
/// Returns `CompletionError::ResponseError` when `OneOrMany::many` rejects the
/// collected items.
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
        let mut parts = Vec::with_capacity(response.choices.len());
        for item in &response.choices {
            parts.push(completion::AssistantContent::from(item));
        }

        match OneOrMany::many(parts) {
            Ok(choice) => Ok(completion::CompletionResponse {
                choice,
                raw_response: response,
            }),
            Err(_) => Err(CompletionError::ResponseError(
                "Response contained no message or tool call (empty)".to_owned(),
            )),
        }
    }
}
/// `max_tokens` value sent with a completion request when the request itself
/// does not specify one.
const MAX_TOKENS: u64 = 8192;
/// Completion model bound to a specific Inception model name.
#[derive(Clone)]
pub struct CompletionModel {
/// Client used to send the completion request.
pub(crate) client: Client,
/// Model name sent as the `model` field of each request.
pub model: String,
}
impl CompletionModel {
    /// Binds `client` to the model named `model`; the name is copied into an
    /// owned `String`.
    pub fn new(client: Client, model: &str) -> Self {
        let model = String::from(model);
        Self { client, model }
    }
}
impl completion::CompletionModel for CompletionModel {
type Response = CompletionResponse;
#[cfg_attr(feature = "worker", worker::send)]
async fn completion(
&self,
completion_request: completion::CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
let max_tokens = completion_request.max_tokens.unwrap_or(MAX_TOKENS);
let prompt_message: Message = completion_request
.prompt_with_context()
.try_into()
.map_err(|e: MessageError| CompletionError::RequestError(e.into()))?;
let mut messages = completion_request
.chat_history
.into_iter()
.map(|message| {
message
.try_into()
.map_err(|e: MessageError| CompletionError::RequestError(e.into()))
})
.collect::<Result<Vec<Message>, _>>()?;
messages.push(prompt_message);
let request = json!({
"model": self.model,
"messages": messages,
"max_tokens": max_tokens,
});
let response = self
.client
.post("/chat/completions")
.json(&request)
.send()
.await?;
if response.status().is_success() {
let response = response.json::<CompletionResponse>().await?;
tracing::info!(target: "rig",
"Inception completion token usage: {}",
response.usage
);
Ok(response.try_into()?)
} else {
Err(CompletionError::ProviderError(response.text().await?))
}
}
}

View File

@ -0,0 +1,2 @@
/// HTTP client and client builder for the Inception API.
pub mod client;
/// Completion request/response types and the Inception `CompletionModel`.
pub mod completion;

View File

@ -11,6 +11,7 @@
//! - DeepSeek
//! - Azure OpenAI
//! - Mira
//! - Inception
//!
//! Each provider has its own module, which contains a `Client` implementation that can
//! be used to initialize completion and embedding models and execute requests to those models.
@ -54,6 +55,7 @@ pub mod gemini;
pub mod groq;
pub mod huggingface;
pub mod hyperbolic;
pub mod inception;
pub mod mira;
pub mod moonshot;
pub mod ollama;