mirror of https://github.com/0xplaygrounds/rig
Revert "Simplify inception provider module to use openai streaming api"
This reverts commit 01c93176cf
.
This commit is contained in:
parent
01c93176cf
commit
efcf180798
|
@ -2,18 +2,21 @@ use std::env;
|
|||
|
||||
use rig::{
|
||||
completion::Prompt,
|
||||
providers::inception::{Client, MERCURY_CODER_SMALL},
|
||||
providers::inception::{ClientBuilder, MERCURY_CODER_SMALL},
|
||||
};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
// Create Inception Labs client
|
||||
let client = Client::new(&env::var("INCEPTION_API_KEY").expect("INCEPTION_API_KEY not set"));
|
||||
let client =
|
||||
ClientBuilder::new(&env::var("INCEPTION_API_KEY").expect("INCEPTION_API_KEY not set"))
|
||||
.build();
|
||||
|
||||
// Create agent with a single context prompt
|
||||
let agent = client
|
||||
.agent(MERCURY_CODER_SMALL)
|
||||
.preamble("You are a helpful AI assistant.")
|
||||
.temperature(0.0)
|
||||
.build();
|
||||
|
||||
// Prompt the agent and print the response
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use rig::{
|
||||
providers::inception::{self, MERCURY_CODER_SMALL},
|
||||
providers::inception::{self, completion::MERCURY_CODER_SMALL},
|
||||
streaming::{stream_to_stdout, StreamingPrompt},
|
||||
};
|
||||
|
||||
|
|
|
@ -1,325 +0,0 @@
|
|||
use super::openai::send_compatible_streaming_request;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use crate::{
|
||||
agent::AgentBuilder,
|
||||
completion::{self, CompletionError, CompletionRequest},
|
||||
extractor::ExtractorBuilder,
|
||||
json_utils::{self, merge_inplace},
|
||||
message::{self, MessageError},
|
||||
streaming::{StreamingCompletionModel, StreamingResult},
|
||||
OneOrMany,
|
||||
};
|
||||
|
||||
const INCEPTION_API_BASE_URL: &str = "https://api.inceptionlabs.ai/v1";
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Client {
|
||||
base_url: String,
|
||||
http_client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
pub fn new(api_key: &str) -> Self {
|
||||
Self {
|
||||
base_url: INCEPTION_API_BASE_URL.to_string(),
|
||||
http_client: reqwest::Client::builder()
|
||||
.default_headers({
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
"Content-Type",
|
||||
"application/json"
|
||||
.parse()
|
||||
.expect("Content-Type should parse"),
|
||||
);
|
||||
headers.insert(
|
||||
"Authorization",
|
||||
format!("Bearer {}", api_key)
|
||||
.parse()
|
||||
.expect("Authorization should parse"),
|
||||
);
|
||||
headers
|
||||
})
|
||||
.build()
|
||||
.expect("Inception reqwest client should build"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_env() -> Self {
|
||||
let api_key = std::env::var("INCEPTION_API_KEY").expect("INCEPTION_API_KEY not set");
|
||||
Client::new(&api_key)
|
||||
}
|
||||
|
||||
pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
|
||||
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
|
||||
self.http_client.post(url)
|
||||
}
|
||||
|
||||
pub fn completion_model(&self, model: &str) -> CompletionModel {
|
||||
CompletionModel::new(self.clone(), model)
|
||||
}
|
||||
|
||||
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
|
||||
AgentBuilder::new(self.completion_model(model))
|
||||
}
|
||||
|
||||
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
|
||||
&self,
|
||||
model: &str,
|
||||
) -> ExtractorBuilder<T, CompletionModel> {
|
||||
ExtractorBuilder::new(self.completion_model(model))
|
||||
}
|
||||
}
|
||||
|
||||
// ================================================================
|
||||
// Inception Completion API
|
||||
// ================================================================
|
||||
/// `mercury-coder-small` completion model
|
||||
pub const MERCURY_CODER_SMALL: &str = "mercury-coder-small";
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CompletionResponse {
|
||||
pub id: String,
|
||||
pub choices: Vec<Choice>,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Choice {
|
||||
pub index: usize,
|
||||
pub message: Message,
|
||||
pub finish_reason: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ApiErrorResponse {
|
||||
message: String,
|
||||
}
|
||||
|
||||
impl From<ApiErrorResponse> for CompletionError {
|
||||
fn from(err: ApiErrorResponse) -> Self {
|
||||
CompletionError::ProviderError(err.message)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum ApiResponse<T> {
|
||||
Ok(T),
|
||||
Err(ApiErrorResponse),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Usage {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Prompt tokens: {}\nCompletion tokens: {}\nTotal tokens: {}",
|
||||
self.prompt_tokens, self.completion_tokens, self.total_tokens
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct Message {
|
||||
pub role: Role,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
System,
|
||||
User,
|
||||
Assistant,
|
||||
}
|
||||
|
||||
impl TryFrom<message::Message> for Message {
|
||||
type Error = MessageError;
|
||||
|
||||
fn try_from(message: message::Message) -> Result<Self, Self::Error> {
|
||||
Ok(match message {
|
||||
message::Message::User { content } => Message {
|
||||
role: Role::User,
|
||||
content: match content.first() {
|
||||
message::UserContent::Text(message::Text { text }) => text.clone(),
|
||||
_ => {
|
||||
return Err(MessageError::ConversionError(
|
||||
"User message content must be a text message".to_string(),
|
||||
))
|
||||
}
|
||||
},
|
||||
},
|
||||
message::Message::Assistant { content } => Message {
|
||||
role: Role::Assistant,
|
||||
content: match content.first() {
|
||||
message::AssistantContent::Text(message::Text { text }) => text.clone(),
|
||||
_ => {
|
||||
return Err(MessageError::ConversionError(
|
||||
"Assistant message content must be a text message".to_string(),
|
||||
))
|
||||
}
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
|
||||
type Error = CompletionError;
|
||||
|
||||
fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
|
||||
let choice = response.choices.first().ok_or_else(|| {
|
||||
CompletionError::ResponseError("Response contained no choices".to_owned())
|
||||
})?;
|
||||
|
||||
let content = match &choice.message.role {
|
||||
Role::Assistant => {
|
||||
let content = completion::AssistantContent::text(&choice.message.content);
|
||||
|
||||
Ok(content)
|
||||
}
|
||||
_ => Err(CompletionError::ResponseError(
|
||||
"Response did not contain a valid message".into(),
|
||||
)),
|
||||
}?;
|
||||
|
||||
let choice = OneOrMany::one(content);
|
||||
|
||||
Ok(completion::CompletionResponse {
|
||||
choice,
|
||||
raw_response: response,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const MAX_TOKENS: u64 = 8192;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CompletionModel {
|
||||
client: Client,
|
||||
/// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
impl CompletionModel {
|
||||
pub(crate) fn create_completion_request(
|
||||
&self,
|
||||
completion_request: CompletionRequest,
|
||||
) -> Result<Value, CompletionError> {
|
||||
let mut messages = vec![];
|
||||
|
||||
if let Some(preamble) = completion_request.preamble.clone() {
|
||||
messages.push(Message {
|
||||
role: Role::System,
|
||||
content: preamble.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
let prompt_message: Message = completion_request
|
||||
.prompt_with_context()
|
||||
.try_into()
|
||||
.map_err(|e: MessageError| CompletionError::RequestError(e.into()))?;
|
||||
|
||||
let chat_history = completion_request
|
||||
.chat_history
|
||||
.into_iter()
|
||||
.map(|message| {
|
||||
message
|
||||
.try_into()
|
||||
.map_err(|e: MessageError| CompletionError::RequestError(e.into()))
|
||||
})
|
||||
.collect::<Result<Vec<Message>, _>>()?;
|
||||
|
||||
messages.extend(chat_history);
|
||||
messages.push(prompt_message);
|
||||
|
||||
let max_tokens = completion_request.max_tokens.unwrap_or(MAX_TOKENS);
|
||||
|
||||
let request = json!({
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
// The beta API reference doesn't mention temperature but it doesn't hurt to include it
|
||||
"temperature": completion_request.temperature,
|
||||
"max_tokens": max_tokens,
|
||||
});
|
||||
|
||||
let request = if let Some(params) = completion_request.additional_params {
|
||||
json_utils::merge(request, params)
|
||||
} else {
|
||||
request
|
||||
};
|
||||
|
||||
Ok(request)
|
||||
}
|
||||
}
|
||||
|
||||
impl CompletionModel {
|
||||
pub fn new(client: Client, model: &str) -> Self {
|
||||
Self {
|
||||
client,
|
||||
model: model.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl completion::CompletionModel for CompletionModel {
|
||||
type Response = CompletionResponse;
|
||||
|
||||
#[cfg_attr(feature = "worker", worker::send)]
|
||||
async fn completion(
|
||||
&self,
|
||||
completion_request: CompletionRequest,
|
||||
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
|
||||
let request = self.create_completion_request(completion_request)?;
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post("/chat/completions")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if response.status().is_success() {
|
||||
match response.json::<ApiResponse<CompletionResponse>>().await? {
|
||||
ApiResponse::Ok(response) => {
|
||||
tracing::info!(target: "rig",
|
||||
"Inception completion token usage: {}",
|
||||
response.usage
|
||||
);
|
||||
|
||||
response.try_into()
|
||||
}
|
||||
ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
|
||||
}
|
||||
} else {
|
||||
Err(CompletionError::ProviderError(response.text().await?))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
async fn stream(
|
||||
&self,
|
||||
completion_request: CompletionRequest,
|
||||
) -> Result<StreamingResult, CompletionError> {
|
||||
let mut request = self.create_completion_request(completion_request)?;
|
||||
|
||||
merge_inplace(&mut request, json!({"stream": true}));
|
||||
|
||||
let builder = self.client.post("/chat/completions").json(&request);
|
||||
|
||||
send_compatible_streaming_request(builder).await
|
||||
}
|
||||
}
|
|
@ -0,0 +1,91 @@
|
|||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
agent::AgentBuilder, extractor::ExtractorBuilder,
|
||||
providers::inception::completion::CompletionModel,
|
||||
};
|
||||
|
||||
const INCEPTION_API_BASE_URL: &str = "https://api.inceptionlabs.ai/v1";
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ClientBuilder<'a> {
|
||||
api_key: &'a str,
|
||||
base_url: &'a str,
|
||||
}
|
||||
|
||||
impl<'a> ClientBuilder<'a> {
|
||||
pub fn new(api_key: &'a str) -> Self {
|
||||
Self {
|
||||
api_key,
|
||||
base_url: INCEPTION_API_BASE_URL,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn base_url(mut self, base_url: &'a str) -> Self {
|
||||
self.base_url = base_url;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Client {
|
||||
Client::new(self.api_key, self.base_url)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Client {
|
||||
base_url: String,
|
||||
http_client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
pub fn new(api_key: &str, base_url: &str) -> Self {
|
||||
Self {
|
||||
base_url: base_url.to_string(),
|
||||
http_client: reqwest::Client::builder()
|
||||
.default_headers({
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
"Content-Type",
|
||||
"application/json"
|
||||
.parse()
|
||||
.expect("Content-Type should parse"),
|
||||
);
|
||||
headers.insert(
|
||||
"Authorization",
|
||||
format!("Bearer {}", api_key)
|
||||
.parse()
|
||||
.expect("Authorization should parse"),
|
||||
);
|
||||
headers
|
||||
})
|
||||
.build()
|
||||
.expect("Inception reqwest client should build"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_env() -> Self {
|
||||
let api_key = std::env::var("INCEPTION_API_KEY").expect("INCEPTION_API_KEY not set");
|
||||
ClientBuilder::new(&api_key).build()
|
||||
}
|
||||
|
||||
pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
|
||||
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
|
||||
self.http_client.post(url)
|
||||
}
|
||||
|
||||
pub fn completion_model(&self, model: &str) -> CompletionModel {
|
||||
CompletionModel::new(self.clone(), model)
|
||||
}
|
||||
|
||||
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
|
||||
AgentBuilder::new(self.completion_model(model))
|
||||
}
|
||||
|
||||
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
|
||||
&self,
|
||||
model: &str,
|
||||
) -> ExtractorBuilder<T, CompletionModel> {
|
||||
ExtractorBuilder::new(self.completion_model(model))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,197 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::{
|
||||
completion::{self, CompletionError},
|
||||
message::{self, MessageError},
|
||||
OneOrMany,
|
||||
};
|
||||
|
||||
use super::client::Client;
|
||||
|
||||
// ================================================================
|
||||
// Inception Completion API
|
||||
// ================================================================
|
||||
/// `mercury-coder-small` completion model
|
||||
pub const MERCURY_CODER_SMALL: &str = "mercury-coder-small";
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CompletionResponse {
|
||||
pub id: String,
|
||||
pub choices: Vec<Choice>,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Choice {
|
||||
pub index: usize,
|
||||
pub message: Message,
|
||||
pub finish_reason: String,
|
||||
}
|
||||
|
||||
impl From<Choice> for completion::AssistantContent {
|
||||
fn from(choice: Choice) -> Self {
|
||||
completion::AssistantContent::from(&choice)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&Choice> for completion::AssistantContent {
|
||||
fn from(choice: &Choice) -> Self {
|
||||
completion::AssistantContent::Text(completion::message::Text {
|
||||
text: choice.message.content.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct Message {
|
||||
pub role: Role,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Assistant,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Usage {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Prompt tokens: {}\nCompletion tokens: {}\nTotal tokens: {}",
|
||||
self.prompt_tokens, self.completion_tokens, self.total_tokens
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<message::Message> for Message {
|
||||
type Error = MessageError;
|
||||
|
||||
fn try_from(message: message::Message) -> Result<Self, Self::Error> {
|
||||
Ok(match message {
|
||||
message::Message::User { content } => Message {
|
||||
role: Role::User,
|
||||
content: match content.first() {
|
||||
message::UserContent::Text(message::Text { text }) => text.clone(),
|
||||
_ => {
|
||||
return Err(MessageError::ConversionError(
|
||||
"User message content must be a text message".to_string(),
|
||||
))
|
||||
}
|
||||
},
|
||||
},
|
||||
message::Message::Assistant { content } => Message {
|
||||
role: Role::Assistant,
|
||||
content: match content.first() {
|
||||
message::AssistantContent::Text(message::Text { text }) => text.clone(),
|
||||
_ => {
|
||||
return Err(MessageError::ConversionError(
|
||||
"Assistant message content must be a text message".to_string(),
|
||||
))
|
||||
}
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
|
||||
type Error = CompletionError;
|
||||
|
||||
fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
|
||||
let content = response.choices.iter().map(Into::into).collect::<Vec<_>>();
|
||||
|
||||
let choice = OneOrMany::many(content).map_err(|_| {
|
||||
CompletionError::ResponseError(
|
||||
"Response contained no message or tool call (empty)".to_owned(),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(completion::CompletionResponse {
|
||||
choice,
|
||||
raw_response: response,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const MAX_TOKENS: u64 = 8192;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CompletionModel {
|
||||
pub(crate) client: Client,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
impl CompletionModel {
|
||||
pub fn new(client: Client, model: &str) -> Self {
|
||||
Self {
|
||||
client,
|
||||
model: model.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl completion::CompletionModel for CompletionModel {
|
||||
type Response = CompletionResponse;
|
||||
|
||||
#[cfg_attr(feature = "worker", worker::send)]
|
||||
async fn completion(
|
||||
&self,
|
||||
completion_request: completion::CompletionRequest,
|
||||
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
|
||||
let max_tokens = completion_request.max_tokens.unwrap_or(MAX_TOKENS);
|
||||
|
||||
let prompt_message: Message = completion_request
|
||||
.prompt_with_context()
|
||||
.try_into()
|
||||
.map_err(|e: MessageError| CompletionError::RequestError(e.into()))?;
|
||||
|
||||
let mut messages = completion_request
|
||||
.chat_history
|
||||
.into_iter()
|
||||
.map(|message| {
|
||||
message
|
||||
.try_into()
|
||||
.map_err(|e: MessageError| CompletionError::RequestError(e.into()))
|
||||
})
|
||||
.collect::<Result<Vec<Message>, _>>()?;
|
||||
|
||||
messages.push(prompt_message);
|
||||
|
||||
let request = json!({
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens,
|
||||
});
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post("/chat/completions")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let response = response.json::<CompletionResponse>().await?;
|
||||
tracing::info!(target: "rig",
|
||||
"Inception completion token usage: {}",
|
||||
response.usage
|
||||
);
|
||||
Ok(response.try_into()?)
|
||||
} else {
|
||||
Err(CompletionError::ProviderError(response.text().await?))
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
pub mod client;
|
||||
pub mod completion;
|
||||
pub mod streaming;
|
||||
|
||||
pub use client::{Client, ClientBuilder};
|
||||
pub use completion::MERCURY_CODER_SMALL;
|
|
@ -0,0 +1,121 @@
|
|||
use async_stream::stream;
|
||||
use futures::StreamExt;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
|
||||
use super::completion::{CompletionModel, Message};
|
||||
use crate::completion::{CompletionError, CompletionRequest};
|
||||
use crate::json_utils::merge_inplace;
|
||||
use crate::message::MessageError;
|
||||
use crate::providers::anthropic::decoders::sse::from_response as sse_from_response;
|
||||
use crate::streaming::{self, StreamingCompletionModel, StreamingResult};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct StreamingResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<StreamingChoice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct StreamingChoice {
|
||||
pub index: usize,
|
||||
pub delta: Delta,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Delta {
|
||||
pub content: Option<String>,
|
||||
pub role: Option<String>,
|
||||
}
|
||||
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
async fn stream(
|
||||
&self,
|
||||
completion_request: CompletionRequest,
|
||||
) -> Result<StreamingResult, CompletionError> {
|
||||
let prompt_message: Message = completion_request
|
||||
.prompt_with_context()
|
||||
.try_into()
|
||||
.map_err(|e: MessageError| CompletionError::RequestError(e.into()))?;
|
||||
|
||||
let mut messages = completion_request
|
||||
.chat_history
|
||||
.into_iter()
|
||||
.map(|message| {
|
||||
message
|
||||
.try_into()
|
||||
.map_err(|e: MessageError| CompletionError::RequestError(e.into()))
|
||||
})
|
||||
.collect::<Result<Vec<Message>, _>>()?;
|
||||
|
||||
messages.push(prompt_message);
|
||||
|
||||
let mut request = json!({
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"max_tokens": completion_request.max_tokens.unwrap_or(8192),
|
||||
"stream": true,
|
||||
});
|
||||
|
||||
if let Some(temperature) = completion_request.temperature {
|
||||
merge_inplace(&mut request, json!({ "temperature": temperature }));
|
||||
}
|
||||
|
||||
if let Some(ref params) = completion_request.additional_params {
|
||||
merge_inplace(&mut request, params.clone())
|
||||
}
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post("chat/completions")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(CompletionError::ProviderError(response.text().await?));
|
||||
}
|
||||
|
||||
// Use our SSE decoder to directly handle Server-Sent Events format
|
||||
let sse_stream = sse_from_response(response);
|
||||
|
||||
Ok(Box::pin(stream! {
|
||||
let mut sse_stream = Box::pin(sse_stream);
|
||||
|
||||
while let Some(sse_result) = sse_stream.next().await {
|
||||
match sse_result {
|
||||
Ok(sse) => {
|
||||
// Parse the SSE data as a StreamingResponse
|
||||
match serde_json::from_str::<StreamingResponse>(&sse.data) {
|
||||
Ok(response) => {
|
||||
if let Some(choice) = response.choices.first() {
|
||||
if let Some(content) = &choice.delta.content {
|
||||
yield Ok(streaming::StreamingChoice::Message(content.clone()));
|
||||
}
|
||||
if choice.finish_reason.as_deref() == Some("stop") {
|
||||
break;
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
if !sse.data.trim().is_empty() {
|
||||
yield Err(CompletionError::ResponseError(
|
||||
format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
yield Err(CompletionError::ResponseError(format!("SSE Error: {}", e)));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue