diff --git a/rig-core/src/providers/anthropic/streaming.rs b/rig-core/src/providers/anthropic/streaming.rs
index b351515..85159f4 100644
--- a/rig-core/src/providers/anthropic/streaming.rs
+++ b/rig-core/src/providers/anthropic/streaming.rs
@@ -8,7 +8,7 @@
 use super::decoders::sse::from_response as sse_from_response;
 use crate::completion::{CompletionError, CompletionRequest};
 use crate::json_utils::merge_inplace;
 use crate::message::MessageError;
-use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
+use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult};
 #[derive(Debug, Deserialize)]
 #[serde(tag = "type", rename_all = "snake_case")]
@@ -191,12 +191,12 @@ impl StreamingCompletionModel for CompletionModel {
 fn handle_event(
     event: &StreamingEvent,
     current_tool_call: &mut Option<ToolCallState>,
-) -> Option<Result<StreamingChoice, CompletionError>> {
+) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
     match event {
         StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
             ContentDelta::TextDelta { text } => {
                 if current_tool_call.is_none() {
-                    return Some(Ok(StreamingChoice::Message(text.clone())));
+                    return Some(Ok(RawStreamingChoice::Message(text.clone())));
                 }
                 None
             }
@@ -227,7 +227,7 @@ fn handle_event(
                     &tool_call.input_json
                 };
                 match serde_json::from_str(json_str) {
-                    Ok(json_value) => Some(Ok(StreamingChoice::ToolCall(
+                    Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall(
                         tool_call.name,
                         tool_call.id,
                         json_value,
diff --git a/rig-core/src/providers/azure.rs b/rig-core/src/providers/azure.rs
index c2e0ec9..4c3b8b3 100644
--- a/rig-core/src/providers/azure.rs
+++ b/rig-core/src/providers/azure.rs
@@ -12,7 +12,7 @@
 use super::openai::{send_compatible_streaming_request, TranscriptionResponse};
 use crate::json_utils::merge;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -570,10 +570,17 @@ impl completion::CompletionModel for CompletionModel {
 // Azure OpenAI Streaming API
 // -----------------------------------------------------
 impl StreamingCompletionModel for CompletionModel {
-    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
+    type Response = openai::StreamingCompletionResponse;
+    async fn stream(
+        &self,
+        request: CompletionRequest,
+    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
         let mut request = self.create_completion_request(request)?;

-        request = merge(request, json!({"stream": true}));
+        request = merge(
+            request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );

         let builder = self
             .client
diff --git a/rig-core/src/providers/deepseek.rs b/rig-core/src/providers/deepseek.rs
index 86d9a0f..71a058b 100644
--- a/rig-core/src/providers/deepseek.rs
+++ b/rig-core/src/providers/deepseek.rs
@@ -10,8 +10,9 @@
 //! ```
 use crate::json_utils::merge;
+use crate::providers::openai;
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
 use crate::{
     completion::{self, CompletionError, CompletionModel, CompletionRequest},
     extractor::ExtractorBuilder,
@@ -463,13 +464,17 @@ impl CompletionModel for DeepSeekCompletionModel {
 }

 impl StreamingCompletionModel for DeepSeekCompletionModel {
+    type Response = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;

-        request = merge(request, json!({"stream": true}));
+        request = merge(
+            request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );

         let builder = self.client.post("/v1/chat/completions").json(&request);
         send_compatible_streaming_request(builder).await
diff --git a/rig-core/src/providers/galadriel.rs b/rig-core/src/providers/galadriel.rs
index 3b536c9..84d9bf8 100644
--- a/rig-core/src/providers/galadriel.rs
+++ b/rig-core/src/providers/galadriel.rs
@@ -13,7 +13,7 @@
 use super::openai;
 use crate::json_utils::merge;
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -495,10 +495,18 @@ impl completion::CompletionModel for CompletionModel {
 }

 impl StreamingCompletionModel for CompletionModel {
-    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
+    type Response = openai::StreamingCompletionResponse;
+
+    async fn stream(
+        &self,
+        request: CompletionRequest,
+    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
         let mut request = self.create_completion_request(request)?;

-        request = merge(request, json!({"stream": true}));
+        request = merge(
+            request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );

         let builder = self.client.post("/chat/completions").json(&request);
diff --git a/rig-core/src/providers/gemini/streaming.rs b/rig-core/src/providers/gemini/streaming.rs
index 362d9ae..b3240bc 100644
--- a/rig-core/src/providers/gemini/streaming.rs
+++ b/rig-core/src/providers/gemini/streaming.rs
@@ -74,9 +74,9 @@ impl StreamingCompletionModel for CompletionModel {

                     match choice.content.parts.first() {
                         super::completion::gemini_api_types::Part::Text(text)
-                            => yield Ok(streaming::StreamingChoice::Message(text)),
+                            => yield Ok(streaming::RawStreamingChoice::Message(text)),
                         super::completion::gemini_api_types::Part::FunctionCall(function_call)
-                            => yield Ok(streaming::StreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
+                            => yield Ok(streaming::RawStreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
                         _ => panic!("Unsupported response type with streaming.")
                     };
                 }
diff --git a/rig-core/src/providers/groq.rs b/rig-core/src/providers/groq.rs
index f852e77..08ff2ee 100644
--- a/rig-core/src/providers/groq.rs
+++ b/rig-core/src/providers/groq.rs
@@ -10,7 +10,8 @@
 //! ```
 use super::openai::{send_compatible_streaming_request, CompletionResponse, TranscriptionResponse};
 use crate::json_utils::merge;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::providers::openai;
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -363,10 +364,17 @@ impl completion::CompletionModel for CompletionModel {
 }

 impl StreamingCompletionModel for CompletionModel {
-    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
+    type Response = openai::StreamingCompletionResponse;
+    async fn stream(
+        &self,
+        request: CompletionRequest,
+    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
         let mut request = self.create_completion_request(request)?;

-        request = merge(request, json!({"stream": true}));
+        request = merge(
+            request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );

         let builder = self.client.post("/chat/completions").json(&request);
diff --git a/rig-core/src/providers/hyperbolic.rs b/rig-core/src/providers/hyperbolic.rs
index 1606635..00098e6 100644
--- a/rig-core/src/providers/hyperbolic.rs
+++ b/rig-core/src/providers/hyperbolic.rs
@@ -12,7 +12,7 @@
 use super::openai::{send_compatible_streaming_request, AssistantContent};
 use crate::json_utils::merge_inplace;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -390,13 +390,17 @@ impl completion::CompletionModel for CompletionModel {
 }

 impl StreamingCompletionModel for CompletionModel {
+    type Response = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;

-        merge_inplace(&mut request, json!({"stream": true}));
+        merge_inplace(
+            &mut request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );

         let builder = self.client.post("/chat/completions").json(&request);

@@ -526,8 +530,10 @@ mod image_generation {
 // ======================================
 // Hyperbolic Audio Generation API
 // ======================================
+use crate::providers::openai;
 #[cfg(feature = "audio")]
 pub use audio_generation::*;
+
 #[cfg(feature = "audio")]
 mod audio_generation {
     use super::{ApiResponse, Client};
diff --git a/rig-core/src/providers/mira.rs b/rig-core/src/providers/mira.rs
index 1d11c19..1f43eb6 100644
--- a/rig-core/src/providers/mira.rs
+++ b/rig-core/src/providers/mira.rs
@@ -9,7 +9,7 @@
 //! ```
 use crate::json_utils::merge;
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -24,6 +24,7 @@ use serde_json::{json, Value};
 use std::string::FromUtf8Error;
 use thiserror::Error;
 use tracing;
+use crate::providers::openai;

 #[derive(Debug, Error)]
 pub enum MiraError {
@@ -347,10 +348,11 @@ impl completion::CompletionModel for CompletionModel {
 }

 impl StreamingCompletionModel for CompletionModel {
+    type Response = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;

         request = merge(request, json!({"stream": true}));
diff --git a/rig-core/src/providers/moonshot.rs b/rig-core/src/providers/moonshot.rs
index 278863a..2906a89 100644
--- a/rig-core/src/providers/moonshot.rs
+++ b/rig-core/src/providers/moonshot.rs
@@ -11,7 +11,7 @@
 use crate::json_utils::merge;
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -228,10 +228,18 @@ impl completion::CompletionModel for CompletionModel {
 }

 impl StreamingCompletionModel for CompletionModel {
-    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
+    type Response = openai::StreamingCompletionResponse;
+
+    async fn stream(
+        &self,
+        request: CompletionRequest,
+    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
         let mut request = self.create_completion_request(request)?;

-        request = merge(request, json!({"stream": true}));
+        request = merge(
+            request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );

         let builder = self.client.post("/chat/completions").json(&request);
diff --git a/rig-core/src/providers/ollama.rs b/rig-core/src/providers/ollama.rs
index 4b539b8..1429ad0 100644
--- a/rig-core/src/providers/ollama.rs
+++ b/rig-core/src/providers/ollama.rs
@@ -39,7 +39,7 @@
 //! let extractor = client.extractor::<T>("llama3.2");
 //! ```
 use crate::json_utils::merge_inplace;
-use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
+use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -47,7 +47,7 @@ use crate::{
     extractor::ExtractorBuilder,
     json_utils, message,
     message::{ImageDetail, Text},
-    Embed, OneOrMany,
+    streaming, Embed, OneOrMany,
 };
 use async_stream::stream;
 use futures::StreamExt;
@@ -405,8 +405,31 @@ impl completion::CompletionModel for CompletionModel {
     }
 }

+#[derive(Clone, Debug, Deserialize)]
+pub struct StreamingCompletionResponse {
+    #[serde(default)]
+    pub done_reason: Option<String>,
+    #[serde(default)]
+    pub total_duration: Option<u64>,
+    #[serde(default)]
+    pub load_duration: Option<u64>,
+    #[serde(default)]
+    pub prompt_eval_count: Option<u64>,
+    #[serde(default)]
+    pub prompt_eval_duration: Option<u64>,
+    #[serde(default)]
+    pub eval_count: Option<u64>,
+    #[serde(default)]
+    pub eval_duration: Option<u64>,
+}
+
 impl StreamingCompletionModel for CompletionModel {
-    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
+    type Response = StreamingCompletionResponse;
+
+    async fn stream(
+        &self,
+        request: CompletionRequest,
+    ) -> Result<streaming::StreamingCompletionResponse<Self::Response>, CompletionError> {
         let mut request_payload = self.create_completion_request(request)?;
         merge_inplace(&mut request_payload, json!({"stream": true}));

@@ -426,7 +449,7 @@ impl StreamingCompletionModel for CompletionModel {
             return Err(CompletionError::ProviderError(err_text));
         }

-        Ok(Box::pin(stream! {
+        let stream = Box::pin(stream! {
             let mut stream = response.bytes_stream();
             while let Some(chunk_result) = stream.next().await {
                 let chunk = match chunk_result {
@@ -456,13 +479,13 @@ impl StreamingCompletionModel for CompletionModel {
                 match response.message {
                     Message::Assistant{ content, tool_calls, .. } => {
                         if !content.is_empty() {
-                            yield Ok(StreamingChoice::Message(content))
+                            yield Ok(RawStreamingChoice::Message(content))
                         }

                         for tool_call in tool_calls.iter() {
                             let function = tool_call.function.clone();
-                            yield Ok(StreamingChoice::ToolCall(function.name, "".to_string(), function.arguments));
+                            yield Ok(RawStreamingChoice::ToolCall(function.name, "".to_string(), function.arguments));
                         }
                     }
                     _ => {
@@ -471,7 +494,9 @@ impl StreamingCompletionModel for CompletionModel {
                     }
                 }
             }
-        }))
+        });
+
+        Ok(streaming::StreamingCompletionResponse::new(stream))
     }
 }
diff --git a/rig-core/src/providers/openai/streaming.rs b/rig-core/src/providers/openai/streaming.rs
index 3727468..af68d58 100644
--- a/rig-core/src/providers/openai/streaming.rs
+++ b/rig-core/src/providers/openai/streaming.rs
@@ -2,6 +2,7 @@ use super::completion::CompletionModel;
 use crate::completion::{CompletionError, CompletionRequest};
 use crate::json_utils;
 use crate::json_utils::merge;
+use crate::providers::openai::Usage;
 use crate::streaming;
 use crate::streaming::{StreamingCompletionModel, StreamingResult};
 use async_stream::stream;
@@ -42,15 +43,22 @@ struct StreamingChoice {
 }

 #[derive(Deserialize)]
-struct StreamingCompletionResponse {
+struct StreamingCompletionChunk {
     choices: Vec<StreamingChoice>,
+    usage: Option<Usage>,
+}
+
+#[derive(Clone)]
+pub struct StreamingCompletionResponse {
+    pub usage: Option<Usage>,
 }

 impl StreamingCompletionModel for CompletionModel {
+    type Response = StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<streaming::StreamingCompletionResponse<Self::Response>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;

         request = merge(request, json!({"stream": true}));
@@ -61,7 +69,7 @@ impl StreamingCompletionModel for CompletionModel {

 pub async fn send_compatible_streaming_request(
     request_builder: RequestBuilder,
-) -> Result<StreamingResult, CompletionError> {
+) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError> {
     let response = request_builder.send().await?;

     if !response.status().is_success() {
@@ -73,7 +81,8 @@ pub async fn send_compatible_streaming_request(
     }

     // Handle OpenAI Compatible SSE chunks
-    Ok(Box::pin(stream! {
+    let inner = Box::pin(stream! {
         let mut stream = response.bytes_stream();

         let mut partial_data = None;
+        let mut usage: Option<Usage> = None;
@@ -121,7 +130,7 @@ pub async fn send_compatible_streaming_request(
                     }
                 }

-                let data = serde_json::from_str::<StreamingCompletionResponse>(&line);
+                let data = serde_json::from_str::<StreamingCompletionChunk>(&line);

                 let Ok(data) = data else {
                     continue;
@@ -162,13 +171,17 @@ pub async fn send_compatible_streaming_request(
                             continue;
                         };

-                        yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
+                        yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments))
                     }
                 }
             }

             if let Some(content) = &choice.delta.content {
-                yield Ok(streaming::StreamingChoice::Message(content.clone()))
+                yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
+            }
+
+            if data.usage.is_some() {
+                usage = data.usage.clone();
             }
         }
     }
@@ -178,7 +191,14 @@ pub async fn send_compatible_streaming_request(
                 continue;
             };

-            yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
+            yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments))
         }
-    }))
+
+        // Emit the accumulated usage as the final response once the
+        // SSE stream is exhausted.
+        yield Ok(streaming::RawStreamingChoice::FinalResponse(
+            StreamingCompletionResponse { usage }
+        ));
+    });
+
+    Ok(streaming::StreamingCompletionResponse::new(inner))
 }
diff --git a/rig-core/src/streaming.rs b/rig-core/src/streaming.rs
index 76f91c4..62e7ec2 100644
--- a/rig-core/src/streaming.rs
+++ b/rig-core/src/streaming.rs
@@ -13,14 +13,29 @@ use crate::agent::Agent;
 use crate::completion::{
     CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, Message,
 };
+use crate::message::AssistantContent;
+use crate::OneOrMany;
 use futures::{Stream, StreamExt};
 use std::boxed::Box;
 use std::fmt::{Display, Formatter};
 use std::future::Future;
 use std::pin::Pin;
+use std::task::{Context, Poll};
+
+/// Enum representing a streaming chunk from the model
+#[derive(Debug, Clone)]
+pub enum RawStreamingChoice<R: Clone> {
+    /// A text chunk from a message response
+    Message(String),
+
+    /// A tool call response chunk
+    ToolCall(String, String, serde_json::Value),
+
+    /// The final response object
+    FinalResponse(R),
+}

 /// Enum representing a streaming chunk from the model
-#[derive(Debug)]
 pub enum StreamingChoice {
     /// A text chunk from a message response
     Message(String),
@@ -41,30 +56,97 @@ impl Display for StreamingChoice {
     }
 }

 #[cfg(not(target_arch = "wasm32"))]
-pub type StreamingResult =
-    Pin<Box<dyn Stream<Item = Result<StreamingChoice, CompletionError>> + Send>>;
+pub type StreamingResult<R> =
+    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

 #[cfg(target_arch = "wasm32")]
-pub type StreamingResult = Pin<Box<dyn Stream<Item = Result<StreamingChoice, CompletionError>>>>;
+pub type StreamingResult<R> =
+    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
+
+pub struct StreamingCompletionResponse<R: Clone + Unpin> {
+    inner: StreamingResult<R>,
+    text: String,
+    tool_calls: Vec<(String, String, serde_json::Value)>,
+    pub message: Message,
+    pub response: Option<R>,
+}
+
+impl<R: Clone + Unpin> StreamingCompletionResponse<R> {
+    pub fn new(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
+        Self {
+            inner,
+            text: "".to_string(),
+            tool_calls: vec![],
+            message: Message::assistant(""),
+            response: None,
+        }
+    }
+}
+
+impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
+    type Item = Result<RawStreamingChoice<R>, CompletionError>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let stream = self.get_mut();
+
+        match stream.inner.as_mut().poll_next(cx) {
+            Poll::Pending => Poll::Pending,
+
+            Poll::Ready(None) => {
+                // Stream is exhausted: aggregate the streamed text and tool
+                // calls into the final assistant message.
+                let mut content = vec![AssistantContent::text(stream.text.clone())];
+
+                stream.tool_calls.iter().for_each(|(name, id, args)| {
+                    content.push(AssistantContent::tool_call(id, name, args.clone()));
+                });
+
+                stream.message = Message::Assistant {
+                    content: OneOrMany::many(content)
+                        .expect("there should always be at least one content item"),
+                };
+
+                Poll::Ready(None)
+            }
+            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
+            Poll::Ready(Some(Ok(choice))) => match choice {
+                RawStreamingChoice::Message(text) => {
+                    stream.text = format!("{}{}", stream.text, text);
+                    Poll::Ready(Some(Ok(RawStreamingChoice::Message(text))))
+                }
+                RawStreamingChoice::ToolCall(name, id, args) => {
+                    stream.tool_calls.push((name.clone(), id.clone(), args.clone()));
+                    Poll::Ready(Some(Ok(RawStreamingChoice::ToolCall(name, id, args))))
+                }
+                RawStreamingChoice::FinalResponse(response) => {
+                    // Record the final response and immediately poll again,
+                    // so this internal variant is never surfaced to consumers.
+                    stream.response = Some(response);
+                    Pin::new(stream).poll_next(cx)
+                }
+            },
+        }
+    }
+}

 /// Trait for high-level streaming prompt interface
-pub trait StreamingPrompt: Send + Sync {
+pub trait StreamingPrompt<R: Clone + Unpin>: Send + Sync {
     /// Stream a simple prompt to the model
     fn stream_prompt(
         &self,
         prompt: &str,
-    ) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
+    ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
 }

 /// Trait for high-level streaming chat interface
-pub trait StreamingChat: Send + Sync {
+pub trait StreamingChat<R: Clone + Unpin>: Send + Sync {
     /// Stream a chat with history to the model
     fn stream_chat(
         &self,
         prompt: &str,
         chat_history: Vec<Message>,
-    ) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
+    ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
 }

 /// Trait for low-level streaming completion interface
@@ -78,18 +161,19 @@ pub trait StreamingCompletion<M: StreamingCompletionModel> {
 }

 /// Trait defining a streaming completion model
 pub trait StreamingCompletionModel: CompletionModel {
+    type Response: Clone + Unpin;
     /// Stream a completion response for the given request
     fn stream(
         &self,
         request: CompletionRequest,
-    ) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
+    ) -> impl Future<Output = Result<StreamingCompletionResponse<Self::Response>, CompletionError>>;
 }

 /// helper function to stream a completion request to stdout
 pub async fn stream_to_stdout<M: StreamingCompletionModel>(
     agent: Agent<M>,
-    stream: &mut StreamingResult,
+    stream: &mut StreamingCompletionResponse<M::Response>,
 ) -> Result<(), std::io::Error> {
     print!("Response: ");
     while let Some(chunk) = stream.next().await {
@@ -111,6 +195,7 @@ pub async fn stream_to_stdout<M: StreamingCompletionModel>(
             }
         }
     }
+    println!(); // New line after streaming completes

     Ok(())
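---

For reviewers, a minimal consumption sketch of the reworked API (not part of the diff). The setup below is hypothetical (model name, env-based client), and it assumes the `Agent` impls of `StreamingPrompt`/`StreamingChat` are migrated to the new wrapper in hunks not shown here:

```rust
use rig::providers::openai;
use rig::streaming::{stream_to_stdout, StreamingPrompt};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical setup; any provider implementing StreamingCompletionModel works.
    let agent = openai::Client::from_env().agent("gpt-4o").build();

    // The wrapper is itself a Stream of RawStreamingChoice chunks...
    let mut stream = agent.stream_prompt("Tell me a joke").await?;
    stream_to_stdout(agent, &mut stream).await?;

    // ...and once drained, the aggregates built in poll_next are available:
    // the full assistant message plus the provider-specific final response
    // (token usage, for OpenAI-compatible providers).
    println!("message: {:?}", stream.message);
    if let Some(final_response) = stream.response {
        println!("usage: {:?}", final_response.usage);
    }

    Ok(())
}
```

The design point worth calling out: `RawStreamingChoice::FinalResponse` is swallowed by `StreamingCompletionResponse::poll_next` (it records the value and re-polls), so existing `while let Some(chunk)` consumer loops keep working without ever seeing the new variant.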