mirror of https://github.com/0xplaygrounds/rig
Initial streaming work
This commit is contained in:
parent
329805c9ee
commit
2736b5a4d6
|
@ -0,0 +1,22 @@
|
||||||
|
use rig::{
|
||||||
|
providers::inception::{self, completion::MERCURY_CODER_SMALL},
|
||||||
|
streaming::{stream_to_stdout, StreamingPrompt},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), anyhow::Error> {
|
||||||
|
// Create streaming agent with a single context prompt
|
||||||
|
let agent = inception::Client::from_env()
|
||||||
|
.agent(MERCURY_CODER_SMALL)
|
||||||
|
.preamble("Be precise and concise.")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
// Stream the response and print chunks as they arrive
|
||||||
|
let mut stream = agent
|
||||||
|
.stream_prompt("When and where and what type is the next solar eclipse?")
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
stream_to_stdout(agent, &mut stream).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
|
@ -1,5 +1,6 @@
|
||||||
pub mod client;
|
pub mod client;
|
||||||
pub mod completion;
|
pub mod completion;
|
||||||
|
pub mod streaming;
|
||||||
|
|
||||||
pub use client::{Client, ClientBuilder};
|
pub use client::{Client, ClientBuilder};
|
||||||
pub use completion::MERCURY_CODER_SMALL;
|
pub use completion::MERCURY_CODER_SMALL;
|
||||||
|
|
|
@ -0,0 +1,121 @@
|
||||||
|
use async_stream::stream;
|
||||||
|
use futures::StreamExt;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
use super::completion::{CompletionModel, Message};
|
||||||
|
use crate::completion::{CompletionError, CompletionRequest};
|
||||||
|
use crate::json_utils::merge_inplace;
|
||||||
|
use crate::message::MessageError;
|
||||||
|
use crate::providers::anthropic::decoders::sse::from_response as sse_from_response;
|
||||||
|
use crate::streaming::{self, StreamingCompletionModel, StreamingResult};
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct StreamingResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub created: u64,
|
||||||
|
pub model: String,
|
||||||
|
pub choices: Vec<StreamingChoice>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct StreamingChoice {
|
||||||
|
pub index: usize,
|
||||||
|
pub delta: Delta,
|
||||||
|
pub finish_reason: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct Delta {
|
||||||
|
pub content: Option<String>,
|
||||||
|
pub role: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
|
async fn stream(
|
||||||
|
&self,
|
||||||
|
completion_request: CompletionRequest,
|
||||||
|
) -> Result<StreamingResult, CompletionError> {
|
||||||
|
let prompt_message: Message = completion_request
|
||||||
|
.prompt_with_context()
|
||||||
|
.try_into()
|
||||||
|
.map_err(|e: MessageError| CompletionError::RequestError(e.into()))?;
|
||||||
|
|
||||||
|
let mut messages = completion_request
|
||||||
|
.chat_history
|
||||||
|
.into_iter()
|
||||||
|
.map(|message| {
|
||||||
|
message
|
||||||
|
.try_into()
|
||||||
|
.map_err(|e: MessageError| CompletionError::RequestError(e.into()))
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<Message>, _>>()?;
|
||||||
|
|
||||||
|
messages.push(prompt_message);
|
||||||
|
|
||||||
|
let mut request = json!({
|
||||||
|
"model": self.model,
|
||||||
|
"messages": messages,
|
||||||
|
"max_tokens": completion_request.max_tokens.unwrap_or(8192),
|
||||||
|
"stream": true,
|
||||||
|
});
|
||||||
|
|
||||||
|
if let Some(temperature) = completion_request.temperature {
|
||||||
|
merge_inplace(&mut request, json!({ "temperature": temperature }));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref params) = completion_request.additional_params {
|
||||||
|
merge_inplace(&mut request, params.clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post("chat/completions")
|
||||||
|
.json(&request)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
return Err(CompletionError::ProviderError(response.text().await?));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use our SSE decoder to directly handle Server-Sent Events format
|
||||||
|
let sse_stream = sse_from_response(response);
|
||||||
|
|
||||||
|
Ok(Box::pin(stream! {
|
||||||
|
let mut sse_stream = Box::pin(sse_stream);
|
||||||
|
|
||||||
|
while let Some(sse_result) = sse_stream.next().await {
|
||||||
|
match sse_result {
|
||||||
|
Ok(sse) => {
|
||||||
|
// Parse the SSE data as a StreamingResponse
|
||||||
|
match serde_json::from_str::<StreamingResponse>(&sse.data) {
|
||||||
|
Ok(response) => {
|
||||||
|
if let Some(choice) = response.choices.first() {
|
||||||
|
if let Some(content) = &choice.delta.content {
|
||||||
|
yield Ok(streaming::StreamingChoice::Message(content.clone()));
|
||||||
|
}
|
||||||
|
if choice.finish_reason.as_deref() == Some("stop") {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Err(e) => {
|
||||||
|
if !sse.data.trim().is_empty() {
|
||||||
|
yield Err(CompletionError::ResponseError(
|
||||||
|
format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Err(e) => {
|
||||||
|
yield Err(CompletionError::ResponseError(format!("SSE Error: {}", e)));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue