feat: start refactoring streaming api

yavens 2025-04-08 23:43:50 -04:00
parent 92c91d23c3
commit fcbe648f77
12 changed files with 210 additions and 53 deletions

View File

@@ -8,7 +8,7 @@ use super::decoders::sse::from_response as sse_from_response;
 use crate::completion::{CompletionError, CompletionRequest};
 use crate::json_utils::merge_inplace;
 use crate::message::MessageError;
-use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
+use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult};
 
 #[derive(Debug, Deserialize)]
 #[serde(tag = "type", rename_all = "snake_case")]
@@ -191,12 +191,12 @@ impl StreamingCompletionModel for CompletionModel {
 fn handle_event(
     event: &StreamingEvent,
     current_tool_call: &mut Option<ToolCallState>,
-) -> Option<Result<StreamingChoice, CompletionError>> {
+) -> Option<Result<RawStreamingChoice, CompletionError>> {
     match event {
         StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
            ContentDelta::TextDelta { text } => {
                if current_tool_call.is_none() {
-                    return Some(Ok(StreamingChoice::Message(text.clone())));
+                    return Some(Ok(RawStreamingChoice::Message(text.clone())));
                }
                None
            }
@@ -227,7 +227,7 @@ fn handle_event(
            &tool_call.input_json
        };
        match serde_json::from_str(json_str) {
-           Ok(json_value) => Some(Ok(StreamingChoice::ToolCall(
+           Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall(
                tool_call.name,
                tool_call.id,
                json_value,
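
A note for readers of this hunk: tool-call arguments stream in as partial JSON, so `handle_event` only parses once the block is complete. A minimal standalone sketch of that buffering pattern (illustrative deltas, not Anthropic's actual wire format):

fn buffered_parse() -> Result<serde_json::Value, serde_json::Error> {
    // Partial JSON fragments arrive across several SSE deltas...
    let deltas = ["{\"city\":", " \"Paris\"}"];
    let mut input_json = String::new();
    for delta in deltas {
        input_json.push_str(delta); // ...and are only appended, never parsed early
    }
    // A single parse at the end, as handle_event does with tool_call.input_json.
    serde_json::from_str(&input_json)
}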

View File

@@ -12,7 +12,7 @@
 use super::openai::{send_compatible_streaming_request, TranscriptionResponse};
 use crate::json_utils::merge;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -570,10 +570,17 @@ impl completion::CompletionModel for CompletionModel {
 // Azure OpenAI Streaming API
 // -----------------------------------------------------
 impl StreamingCompletionModel for CompletionModel {
-    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
+    type Response = openai::StreamingCompletionResponse;
+    async fn stream(
+        &self,
+        request: CompletionRequest,
+    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
         let mut request = self.create_completion_request(request)?;
-        request = merge(request, json!({"stream": true}));
+        request = merge(
+            request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );
 
         let builder = self
             .client
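
This commit adds `stream_options.include_usage` to every OpenAI-compatible request so the final streamed chunk reports token usage. Assuming `json_utils::merge` performs a shallow key merge (its body is not shown in this diff), the merged request looks like this sketch:

use serde_json::json;

fn main() {
    let mut request = json!({"model": "gpt-4o", "messages": []});
    // Shallow merge, as assumed for json_utils::merge.
    for (k, v) in json!({"stream": true, "stream_options": {"include_usage": true}})
        .as_object()
        .unwrap()
    {
        request[k] = v.clone();
    }
    // {"model":"gpt-4o","messages":[],"stream":true,
    //  "stream_options":{"include_usage":true}}
    println!("{request}");
}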

View File

@@ -10,8 +10,9 @@
 //! ```
 use crate::json_utils::merge;
+use crate::providers::openai;
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
 use crate::{
     completion::{self, CompletionError, CompletionModel, CompletionRequest},
     extractor::ExtractorBuilder,
@@ -463,13 +464,17 @@ impl CompletionModel for DeepSeekCompletionModel {
 }
 
 impl StreamingCompletionModel for DeepSeekCompletionModel {
+    type Response = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;
-        request = merge(request, json!({"stream": true}));
+        request = merge(
+            request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );
 
         let builder = self.client.post("/v1/chat/completions").json(&request);
         send_compatible_streaming_request(builder).await
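
The new associated `type Response` decides what typed payload the caller gets back when the stream finishes. A hedged caller-side sketch (hypothetical helper function; type paths as named in this diff):

// Hypothetical helper; assumes StreamingCompletionModel is in scope.
async fn call_deepseek(
    model: DeepSeekCompletionModel,
    request: CompletionRequest,
) -> Result<(), CompletionError> {
    // `stream` is a StreamingCompletionResponse<openai::StreamingCompletionResponse>,
    // so once it finishes, stream.response can expose OpenAI-style usage numbers.
    let _stream = model.stream(request).await?;
    Ok(())
}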

View File

@@ -13,7 +13,7 @@
 use super::openai;
 use crate::json_utils::merge;
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -495,10 +495,18 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
-    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
+    type Response = openai::StreamingCompletionResponse;
+
+    async fn stream(
+        &self,
+        request: CompletionRequest,
+    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
         let mut request = self.create_completion_request(request)?;
-        request = merge(request, json!({"stream": true}));
+        request = merge(
+            request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );
 
         let builder = self.client.post("/chat/completions").json(&request);

View File

@@ -74,9 +74,9 @@ impl StreamingCompletionModel for CompletionModel {
             match choice.content.parts.first() {
                 super::completion::gemini_api_types::Part::Text(text)
-                    => yield Ok(streaming::StreamingChoice::Message(text)),
+                    => yield Ok(streaming::RawStreamingChoice::Message(text)),
                 super::completion::gemini_api_types::Part::FunctionCall(function_call)
-                    => yield Ok(streaming::StreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
+                    => yield Ok(streaming::RawStreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
                 _ => panic!("Unsupported response type with streaming.")
             };
         }
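
Reviewer note: the `panic!` arm above aborts the whole consumer task on an unexpected part. A hedged alternative, not what this commit does, would be to surface a provider error through the same channel:

// Same match shape as above, yielding an error instead of panicking:
_ => yield Err(CompletionError::ProviderError(
    "Unsupported response type with streaming.".to_string(),
)),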

View File

@@ -10,7 +10,8 @@
 //! ```
 use super::openai::{send_compatible_streaming_request, CompletionResponse, TranscriptionResponse};
 use crate::json_utils::merge;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::providers::openai;
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -363,10 +364,17 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
-    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
+    type Response = openai::StreamingCompletionResponse;
+    async fn stream(
+        &self,
+        request: CompletionRequest,
+    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
         let mut request = self.create_completion_request(request)?;
-        request = merge(request, json!({"stream": true}));
+        request = merge(
+            request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );
 
         let builder = self.client.post("/chat/completions").json(&request);

View File

@@ -12,7 +12,7 @@
 use super::openai::{send_compatible_streaming_request, AssistantContent};
 use crate::json_utils::merge_inplace;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -390,13 +390,17 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
+    type Response = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;
-        merge_inplace(&mut request, json!({"stream": true}));
+        merge_inplace(
+            &mut request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );
 
         let builder = self.client.post("/chat/completions").json(&request);
@@ -526,8 +530,10 @@ mod image_generation {
 // ======================================
 // Hyperbolic Audio Generation API
 // ======================================
+use crate::providers::openai;
+
 #[cfg(feature = "audio")]
 pub use audio_generation::*;
 #[cfg(feature = "audio")]
 mod audio_generation {
     use super::{ApiResponse, Client};

View File

@@ -9,7 +9,7 @@
 //! ```
 use crate::json_utils::merge;
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -24,6 +24,7 @@ use serde_json::{json, Value};
 use std::string::FromUtf8Error;
 use thiserror::Error;
 use tracing;
+use crate::providers::openai;
 
 #[derive(Debug, Error)]
 pub enum MiraError {
@@ -347,10 +348,11 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
+    type Response = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;
         request = merge(request, json!({"stream": true}));

View File

@@ -11,7 +11,7 @@
 use crate::json_utils::merge;
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -228,10 +228,18 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
-    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
+    type Response = openai::StreamingCompletionResponse;
+
+    async fn stream(
+        &self,
+        request: CompletionRequest,
+    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
         let mut request = self.create_completion_request(request)?;
-        request = merge(request, json!({"stream": true}));
+        request = merge(
+            request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );
 
         let builder = self.client.post("/chat/completions").json(&request);

View File

@@ -39,7 +39,7 @@
 //! let extractor = client.extractor::<serde_json::Value>("llama3.2");
 //! ```
 use crate::json_utils::merge_inplace;
-use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
+use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -47,7 +47,7 @@ use crate::{
     extractor::ExtractorBuilder,
     json_utils, message,
     message::{ImageDetail, Text},
-    Embed, OneOrMany,
+    streaming, Embed, OneOrMany,
 };
 use async_stream::stream;
 use futures::StreamExt;
@@ -405,8 +405,30 @@ impl completion::CompletionModel for CompletionModel {
     }
 }
 
+#[derive(Debug, Deserialize)]
+pub struct StreamingCompletionResponse {
+    #[serde(default)]
+    pub done_reason: Option<String>,
+    #[serde(default)]
+    pub total_duration: Option<u64>,
+    #[serde(default)]
+    pub load_duration: Option<u64>,
+    #[serde(default)]
+    pub prompt_eval_count: Option<u64>,
+    #[serde(default)]
+    pub prompt_eval_duration: Option<u64>,
+    #[serde(default)]
+    pub eval_count: Option<u64>,
+    #[serde(default)]
+    pub eval_duration: Option<u64>,
+}
+
 impl StreamingCompletionModel for CompletionModel {
-    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
+    type Response = StreamingCompletionResponse;
+    async fn stream(
+        &self,
+        request: CompletionRequest,
+    ) -> Result<streaming::StreamingCompletionResponse<Self::Response>, CompletionError> {
         let mut request_payload = self.create_completion_request(request)?;
         merge_inplace(&mut request_payload, json!({"stream": true}));
@@ -426,7 +448,8 @@ impl StreamingCompletionModel for CompletionModel {
             return Err(CompletionError::ProviderError(err_text));
         }
 
-        Ok(Box::pin(stream! {
+        let stream = Box::pin(stream! {
             let mut stream = response.bytes_stream();
             while let Some(chunk_result) = stream.next().await {
                 let chunk = match chunk_result {
@@ -456,13 +479,13 @@ impl StreamingCompletionModel for CompletionModel {
                 match response.message {
                     Message::Assistant{ content, tool_calls, .. } => {
                         if !content.is_empty() {
-                            yield Ok(StreamingChoice::Message(content))
+                            yield Ok(RawStreamingChoice::Message(content))
                        }
 
                        for tool_call in tool_calls.iter() {
                            let function = tool_call.function.clone();
-                            yield Ok(StreamingChoice::ToolCall(function.name, "".to_string(), function.arguments))
+                            yield Ok(RawStreamingChoice::ToolCall(function.name, "".to_string(), function.arguments))
                        }
                    }
                    _ => {
@@ -471,7 +494,7 @@ impl StreamingCompletionModel for CompletionModel {
                    }
                }
            }
-        }))
+        });
+
+        Ok(streaming::StreamingCompletionResponse::new(stream))
     }
 }
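
The new Ollama `StreamingCompletionResponse` mirrors the metadata fields of Ollama's final stream message, with every field defaulting to `None`. A small sketch of how such a final chunk deserializes (illustrative payload; the struct as defined above):

use serde_json::json;

fn demo() -> Result<(), serde_json::Error> {
    // Illustrative final-chunk payload in Ollama's style.
    let last_chunk = json!({
        "done_reason": "stop",
        "eval_count": 42,
        "eval_duration": 1_250_000u64
    });
    let resp: StreamingCompletionResponse = serde_json::from_value(last_chunk)?;
    assert_eq!(resp.done_reason.as_deref(), Some("stop"));
    // Fields absent from the payload fall back to None via #[serde(default)].
    assert_eq!(resp.total_duration, None);
    Ok(())
}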

View File

@@ -2,6 +2,7 @@ use super::completion::CompletionModel;
 use crate::completion::{CompletionError, CompletionRequest};
 use crate::json_utils;
 use crate::json_utils::merge;
+use crate::providers::openai::Usage;
 use crate::streaming;
 use crate::streaming::{StreamingCompletionModel, StreamingResult};
 use async_stream::stream;
@@ -42,15 +44,21 @@ struct StreamingChoice {
 }
 
 #[derive(Deserialize)]
-struct StreamingCompletionResponse {
+struct StreamingCompletionChunk {
     choices: Vec<StreamingChoice>,
+    usage: Option<Usage>,
+}
+
+pub struct StreamingCompletionResponse {
+    usage: Option<Usage>,
 }
 
 impl StreamingCompletionModel for CompletionModel {
+    type Response = StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<streaming::StreamingCompletionResponse<Self::Response>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;
         request = merge(request, json!({"stream": true}));
@@ -61,7 +69,7 @@ impl StreamingCompletionModel for CompletionModel {
 
 pub async fn send_compatible_streaming_request(
     request_builder: RequestBuilder,
-) -> Result<StreamingResult, CompletionError> {
+) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError> {
     let response = request_builder.send().await?;
 
     if !response.status().is_success() {
@@ -73,7 +81,7 @@ pub async fn send_compatible_streaming_request(
     }
 
     // Handle OpenAI Compatible SSE chunks
-    Ok(Box::pin(stream! {
+    let inner = Box::pin(stream! {
         let mut stream = response.bytes_stream();
 
         let mut partial_data = None;
@@ -121,7 +129,7 @@ pub async fn send_compatible_streaming_request(
             }
         }
 
-        let data = serde_json::from_str::<StreamingCompletionResponse>(&line);
+        let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
 
         let Ok(data) = data else {
             continue;
@@ -162,13 +170,17 @@ pub async fn send_compatible_streaming_request(
                         continue;
                     };
 
-                    yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
+                    yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments))
                 }
             }
         }
 
         if let Some(content) = &choice.delta.content {
-            yield Ok(streaming::StreamingChoice::Message(content.clone()))
+            yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
+        }
+
+        if data.usage.is_some() {
+            usage = data.usage;
         }
     }
@@ -178,7 +190,11 @@ pub async fn send_compatible_streaming_request(
             continue;
         };
 
-        yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
+        yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments))
     }
-    }))
+    });
+
+    Ok(streaming::StreamingCompletionResponse::new(
+        inner,
+    ))
 }
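
`StreamingCompletionChunk` is the per-SSE-message wire shape, while the public `StreamingCompletionResponse` carries only the aggregate `usage`. The hunks above do not show where the accumulated `usage` reaches the caller; presumably the generator's final step yields it as a `FinalResponse`, along these lines (a sketch, not part of this diff):

// Hypothetical tail of the stream! block above:
yield Ok(streaming::RawStreamingChoice::FinalResponse(
    StreamingCompletionResponse { usage },
));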

View File

@@ -13,14 +13,28 @@ use crate::agent::Agent;
 use crate::completion::{
     CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, Message,
 };
+use crate::message::AssistantContent;
 use futures::{Stream, StreamExt};
 use std::boxed::Box;
 use std::fmt::{Display, Formatter};
 use std::future::Future;
 use std::pin::Pin;
+use std::task::{Context, Poll};
+
+/// Enum representing a streaming chunk from the model
+#[derive(Debug, Clone)]
+pub enum RawStreamingChoice<R> {
+    /// A text chunk from a message response
+    Message(String),
+
+    /// A tool call response chunk
+    ToolCall(String, String, serde_json::Value),
+
+    /// The final response object
+    FinalResponse(R),
+}
+
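
The type parameter `R` is whatever a provider wants to hand back once the stream ends, which is how each provider's `type Response` plugs in. An illustrative sketch (hypothetical payload types):

// Sketch: two hypothetical providers choosing different final payloads.
struct UsageStats { prompt_tokens: u64, completion_tokens: u64 }

type OpenAiChunk = RawStreamingChoice<UsageStats>; // final chunk carries usage
type SimpleChunk = RawStreamingChoice<()>;         // provider with no final data

fn describe(choice: &OpenAiChunk) -> &'static str {
    match choice {
        RawStreamingChoice::Message(_) => "text delta",
        RawStreamingChoice::ToolCall(_, _, _) => "tool call",
        RawStreamingChoice::FinalResponse(_) => "end-of-stream payload",
    }
}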
 /// Enum representing a streaming chunk from the model
+#[derive(Debug)]
 pub enum StreamingChoice {
     /// A text chunk from a message response
     Message(String),
@@ -41,29 +55,87 @@ impl Display for StreamingChoice {
 }
 
 #[cfg(not(target_arch = "wasm32"))]
-pub type StreamingResult =
-    Pin<Box<dyn Stream<Item = Result<StreamingChoice, CompletionError>> + Send>>;
+pub type StreamingResult<R> =
+    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;
 
 #[cfg(target_arch = "wasm32")]
-pub type StreamingResult = Pin<Box<dyn Stream<Item = Result<StreamingChoice, CompletionError>>>>;
+pub type StreamingResult<R> =
+    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
+
+pub struct StreamingCompletionResponse<R> {
+    inner: StreamingResult<R>,
+    text: String,
+    tool_calls: Vec<(String, String, serde_json::Value)>,
+    pub message: Message,
+    pub response: Option<R>,
+}
+
+impl<R> StreamingCompletionResponse<R> {
+    pub fn new(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
+        Self {
+            inner,
+            text: "".to_string(),
+            tool_calls: vec![],
+            message: Message::assistant(""),
+            response: None,
+        }
+    }
+}
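
Providers wrap their raw SSE stream in this type, as `send_compatible_streaming_request` does above with `StreamingCompletionResponse::new(inner)`. A minimal sketch of wrapping an ad-hoc stream (illustrative only; `async-stream` is already a dependency of the providers):

let raw: StreamingResult<()> = Box::pin(async_stream::stream! {
    yield Ok(RawStreamingChoice::Message("hello ".to_string()));
    yield Ok(RawStreamingChoice::Message("world".to_string()));
    yield Ok(RawStreamingChoice::FinalResponse(()));
});
let wrapped = StreamingCompletionResponse::new(raw);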
+
+impl<R: Unpin> Stream for StreamingCompletionResponse<R> {
+    type Item = Result<StreamingChoice, CompletionError>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let this = self.get_mut();
+
+        match this.inner.as_mut().poll_next(cx) {
+            Poll::Pending => Poll::Pending,
+            Poll::Ready(None) => {
+                // Inner stream exhausted: aggregate the accumulated text and
+                // tool calls into a single assistant message.
+                let mut content = vec![AssistantContent::text(this.text.clone())];
+
+                this.tool_calls.iter().for_each(|(n, d, a)| {
+                    content.push(AssistantContent::tool_call(n.clone(), d.clone(), a.clone()));
+                });
+
+                this.message = Message::Assistant {
+                    content: content.into(),
+                };
+
+                Poll::Ready(None)
+            }
+            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
+            Poll::Ready(Some(Ok(choice))) => match choice {
+                RawStreamingChoice::Message(text) => {
+                    this.text = format!("{}{}", this.text, text);
+                    Poll::Ready(Some(Ok(StreamingChoice::Message(text))))
+                }
+                RawStreamingChoice::ToolCall(name, description, args) => {
+                    this.tool_calls.push((name.clone(), description.clone(), args.clone()));
+                    Poll::Ready(Some(Ok(StreamingChoice::ToolCall(name, description, args))))
+                }
+                RawStreamingChoice::FinalResponse(response) => {
+                    // Keep the provider's final payload, then ask to be polled
+                    // again so the stream can terminate cleanly.
+                    this.response = Some(response);
+                    cx.waker().wake_by_ref();
+                    Poll::Pending
+                }
+            },
+        }
+    }
+}
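
The net effect of `poll_next`: while chunks flow, the wrapper behaves like the old stream; once the inner stream ends, `message` holds the aggregated assistant turn and `response` holds the provider's final payload. A hedged consumer sketch:

async fn drain<R: Unpin>(mut s: StreamingCompletionResponse<R>) -> Result<(), CompletionError> {
    use futures::StreamExt;
    while let Some(chunk) = s.next().await {
        println!("{}", chunk?); // chunks still implement Display as before
    }
    // After exhaustion, the aggregates are available.
    let _full_turn = &s.message;
    let _final_payload = s.response.as_ref();
    Ok(())
}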
 /// Trait for high-level streaming prompt interface
-pub trait StreamingPrompt: Send + Sync {
+pub trait StreamingPrompt<R>: Send + Sync {
     /// Stream a simple prompt to the model
     fn stream_prompt(
         &self,
         prompt: &str,
-    ) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
+    ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
 }
 
 /// Trait for high-level streaming chat interface
-pub trait StreamingChat: Send + Sync {
+pub trait StreamingChat<R>: Send + Sync {
     /// Stream a chat with history to the model
     fn stream_chat(
         &self,
         prompt: &str,
         chat_history: Vec<Message>,
-    ) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
+    ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
 }
 /// Trait for low-level streaming completion interface
@@ -78,17 +150,18 @@ pub trait StreamingCompletion<M: StreamingCompletionModel> {
 
 /// Trait defining a streaming completion model
 pub trait StreamingCompletionModel: CompletionModel {
+    type Response;
     /// Stream a completion response for the given request
     fn stream(
         &self,
         request: CompletionRequest,
-    ) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
+    ) -> impl Future<Output = Result<StreamingCompletionResponse<Self::Response>, CompletionError>>;
 }
 
 /// helper function to stream a completion request to stdout
-pub async fn stream_to_stdout<M: StreamingCompletionModel>(
+pub async fn stream_to_stdout<M: StreamingCompletionModel, R: Unpin>(
     agent: Agent<M>,
-    stream: &mut StreamingResult,
+    stream: &mut StreamingCompletionResponse<R>,
 ) -> Result<(), std::io::Error> {
     print!("Response: ");
     while let Some(chunk) = stream.next().await {
@@ -111,6 +184,7 @@ pub async fn stream_to_stdout<M: StreamingCompletionModel>(
         }
     }
 
     println!(); // New line after streaming completes
     Ok(())
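
Taken together, the refactor keeps the call-site ergonomics of `stream_to_stdout` while letting callers inspect aggregates afterwards. A hedged end-to-end sketch (assumes `Agent` keeps its `StreamingPrompt` implementation under the new signatures):

async fn demo<M>(agent: Agent<M>) -> Result<(), Box<dyn std::error::Error>>
where
    M: StreamingCompletionModel,
    M::Response: std::fmt::Debug + Unpin,
{
    // StreamingPrompt is implemented for Agent, so this is the usual entry point.
    let mut stream = agent.stream_prompt("Tell me a joke").await?;
    stream_to_stdout(agent, &mut stream).await?;
    // New in this refactor: aggregates survive after the stream is drained.
    println!("final provider payload: {:?}", stream.response);
    Ok(())
}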