Initial streaming work

This commit is contained in:
Collin Brittain 2025-04-11 11:54:19 -05:00
parent 329805c9ee
commit 2736b5a4d6
3 changed files with 144 additions and 0 deletions

View File

@ -0,0 +1,22 @@
use rig::{
providers::inception::{self, completion::MERCURY_CODER_SMALL},
streaming::{stream_to_stdout, StreamingPrompt},
};
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Create streaming agent with a single context prompt
let agent = inception::Client::from_env()
.agent(MERCURY_CODER_SMALL)
.preamble("Be precise and concise.")
.build();
// Stream the response and print chunks as they arrive
let mut stream = agent
.stream_prompt("When and where and what type is the next solar eclipse?")
.await?;
stream_to_stdout(agent, &mut stream).await?;
Ok(())
}

View File

@ -1,5 +1,6 @@
pub mod client; pub mod client;
pub mod completion; pub mod completion;
pub mod streaming;
pub use client::{Client, ClientBuilder}; pub use client::{Client, ClientBuilder};
pub use completion::MERCURY_CODER_SMALL; pub use completion::MERCURY_CODER_SMALL;

View File

@ -0,0 +1,121 @@
use async_stream::stream;
use futures::StreamExt;
use serde::Deserialize;
use serde_json::json;
use super::completion::{CompletionModel, Message};
use crate::completion::{CompletionError, CompletionRequest};
use crate::json_utils::merge_inplace;
use crate::message::MessageError;
use crate::providers::anthropic::decoders::sse::from_response as sse_from_response;
use crate::streaming::{self, StreamingCompletionModel, StreamingResult};
#[derive(Debug, Deserialize)]
pub struct StreamingResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<StreamingChoice>,
}
#[derive(Debug, Deserialize)]
pub struct StreamingChoice {
pub index: usize,
pub delta: Delta,
pub finish_reason: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct Delta {
pub content: Option<String>,
pub role: Option<String>,
}
impl StreamingCompletionModel for CompletionModel {
async fn stream(
&self,
completion_request: CompletionRequest,
) -> Result<StreamingResult, CompletionError> {
let prompt_message: Message = completion_request
.prompt_with_context()
.try_into()
.map_err(|e: MessageError| CompletionError::RequestError(e.into()))?;
let mut messages = completion_request
.chat_history
.into_iter()
.map(|message| {
message
.try_into()
.map_err(|e: MessageError| CompletionError::RequestError(e.into()))
})
.collect::<Result<Vec<Message>, _>>()?;
messages.push(prompt_message);
let mut request = json!({
"model": self.model,
"messages": messages,
"max_tokens": completion_request.max_tokens.unwrap_or(8192),
"stream": true,
});
if let Some(temperature) = completion_request.temperature {
merge_inplace(&mut request, json!({ "temperature": temperature }));
}
if let Some(ref params) = completion_request.additional_params {
merge_inplace(&mut request, params.clone())
}
let response = self
.client
.post("chat/completions")
.json(&request)
.send()
.await?;
if !response.status().is_success() {
return Err(CompletionError::ProviderError(response.text().await?));
}
// Use our SSE decoder to directly handle Server-Sent Events format
let sse_stream = sse_from_response(response);
Ok(Box::pin(stream! {
let mut sse_stream = Box::pin(sse_stream);
while let Some(sse_result) = sse_stream.next().await {
match sse_result {
Ok(sse) => {
// Parse the SSE data as a StreamingResponse
match serde_json::from_str::<StreamingResponse>(&sse.data) {
Ok(response) => {
if let Some(choice) = response.choices.first() {
if let Some(content) = &choice.delta.content {
yield Ok(streaming::StreamingChoice::Message(content.clone()));
}
if choice.finish_reason.as_deref() == Some("stop") {
break;
}
}
},
Err(e) => {
if !sse.data.trim().is_empty() {
yield Err(CompletionError::ResponseError(
format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
));
}
}
}
},
Err(e) => {
yield Err(CompletionError::ResponseError(format!("SSE Error: {}", e)));
break;
}
}
}
}))
}
}