diff --git a/rig-core/examples/openai_streaming.rs b/rig-core/examples/openai_streaming.rs
index d4aadf0..a474702 100644
--- a/rig-core/examples/openai_streaming.rs
+++ b/rig-core/examples/openai_streaming.rs
@@ -16,6 +16,10 @@ async fn main() -> Result<(), anyhow::Error> {
         .await?;
 
     stream_to_stdout(agent, &mut stream).await?;
+
+    if let Some(response) = stream.response {
+        println!("Usage: {:?}", response.usage)
+    };
 
     Ok(())
 }
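Note: `stream.response` is only populated once the stream has been fully drained, because each provider yields its `FinalResponse` item right before its stream ends. A hand-rolled consumption loop, sketched against the types in this diff (the `agent`/`stream` setup is assumed to match the example above, and `StreamingChoice` is assumed to have exactly the two variants shown in streaming.rs below):

    // Drain the stream manually instead of using stream_to_stdout, then read
    // the final response the provider attached (token usage, for OpenAI).
    use rig::streaming::StreamingChoice;

    while let Some(chunk) = stream.next().await {
        match chunk? {
            StreamingChoice::Message(text) => print!("{text}"),
            // Tuple order (name, description, args) matches the tool_calls
            // accumulator in streaming.rs below.
            StreamingChoice::ToolCall(name, _description, args) => {
                println!("[tool call] {name}({args})")
            }
        }
    }
    if let Some(response) = stream.response {
        println!("Usage: {:?}", response.usage);
    }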
             .stream()
diff --git a/rig-core/src/agent.rs b/rig-core/src/agent.rs
index 781e2a8..bf5ab0d 100644
--- a/rig-core/src/agent.rs
+++ b/rig-core/src/agent.rs
@@ -110,23 +110,20 @@ use std::collections::HashMap;
 
 use futures::{stream, StreamExt, TryStreamExt};
 
+use crate::streaming::StreamingCompletionResponse;
+#[cfg(feature = "mcp")]
+use crate::tool::McpTool;
 use crate::{
     completion::{
         Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
         Message, Prompt, PromptError,
     },
     message::AssistantContent,
-    streaming::{
-        StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
-        StreamingResult,
-    },
+    streaming::{StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt},
     tool::{Tool, ToolSet},
     vector_store::{VectorStoreError, VectorStoreIndexDyn},
 };
 
-#[cfg(feature = "mcp")]
-use crate::tool::McpTool;
-
 /// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
 /// (i.e.: system prompt) and a static set of context documents and tools.
 /// All context documents and tools are always provided to the agent when prompted.
@@ -500,18 +497,21 @@ impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
     }
 }
 
-impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
-    async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
+impl<M: StreamingCompletionModel> StreamingPrompt<M::StreamingResponse> for Agent<M> {
+    async fn stream_prompt(
+        &self,
+        prompt: &str,
+    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
         self.stream_chat(prompt, vec![]).await
     }
 }
 
-impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
+impl<M: StreamingCompletionModel> StreamingChat<M::StreamingResponse> for Agent<M> {
     async fn stream_chat(
         &self,
         prompt: &str,
         chat_history: Vec<Message>,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
         self.stream_completion(prompt, chat_history)
             .await?
             .stream()
diff --git a/rig-core/src/completion/request.rs b/rig-core/src/completion/request.rs
index 9a31fae..07e8f64 100644
--- a/rig-core/src/completion/request.rs
+++ b/rig-core/src/completion/request.rs
@@ -67,7 +67,7 @@ use std::collections::HashMap;
 use serde::{Deserialize, Serialize};
 use thiserror::Error;
 
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::OneOrMany;
 use crate::{
     json_utils,
@@ -467,7 +467,9 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
 
 impl<M: StreamingCompletionModel> CompletionRequestBuilder<M> {
     /// Stream the completion request
-    pub async fn stream(self) -> Result<StreamingResult, CompletionError> {
+    pub async fn stream(
+        self,
+    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
         let model = self.model.clone();
         model.stream(self.build()).await
     }
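The agent impls above are thin wrappers: `stream_chat` builds a completion request and calls the builder's `stream`, so the provider's `StreamingResponse` type flows through unchanged. The equivalent lower-level call, as a sketch (prompt string arbitrary; `agent` as in the example):

    // Same as agent.stream_chat("Hello!", vec![]): build the request via
    // stream_completion, then stream it; the resulting wrapper is a
    // StreamingCompletionResponse<M::StreamingResponse>.
    let mut stream = agent
        .stream_completion("Hello!", vec![])
        .await?
        .stream()
        .await?;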
diff --git a/rig-core/src/providers/anthropic/streaming.rs b/rig-core/src/providers/anthropic/streaming.rs
index 85159f4..c343b44 100644
--- a/rig-core/src/providers/anthropic/streaming.rs
+++ b/rig-core/src/providers/anthropic/streaming.rs
@@ -8,6 +8,7 @@ use super::decoders::sse::from_response as sse_from_response;
 use crate::completion::{CompletionError, CompletionRequest};
 use crate::json_utils::merge_inplace;
 use crate::message::MessageError;
+use crate::streaming;
 use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult};
 
 #[derive(Debug, Deserialize)]
@@ -61,7 +62,7 @@ pub struct MessageDelta {
     pub stop_sequence: Option<String>,
 }
 
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Deserialize, Clone)]
 pub struct PartialUsage {
     pub output_tokens: usize,
     #[serde(default)]
@@ -75,11 +76,18 @@ struct ToolCallState {
     input_json: String,
 }
 
+#[derive(Clone)]
+pub struct StreamingCompletionResponse {
+    pub usage: PartialUsage,
+}
+
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
+    {
         let max_tokens = if let Some(tokens) = completion_request.max_tokens {
             tokens
         } else if let Some(tokens) = self.default_max_tokens {
@@ -155,9 +163,10 @@ impl StreamingCompletionModel for CompletionModel {
         // Use our SSE decoder to directly handle Server-Sent Events format
         let sse_stream = sse_from_response(response);
 
-        Ok(Box::pin(stream! {
+        let stream: StreamingResult<Self::StreamingResponse> = Box::pin(stream! {
             let mut current_tool_call: Option<ToolCallState> = None;
             let mut sse_stream = Box::pin(sse_stream);
+            let mut input_tokens = 0;
 
             while let Some(sse_result) = sse_stream.next().await {
                 match sse_result {
@@ -165,6 +174,24 @@ impl StreamingCompletionModel for CompletionModel {
                         // Parse the SSE data as a StreamingEvent
                         match serde_json::from_str::<StreamingEvent>(&sse.data) {
                             Ok(event) => {
+                                match &event {
+                                    StreamingEvent::MessageStart { message } => {
+                                        input_tokens = message.usage.input_tokens;
+                                    },
+                                    StreamingEvent::MessageDelta { delta, usage } => {
+                                        if delta.stop_reason.is_some() {
+
+                                            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
+                                                usage: PartialUsage {
+                                                    output_tokens: usage.output_tokens,
+                                                    input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
+                                                }
+                                            }))
+                                        }
+                                    }
+                                    _ => {}
+                                }
+
                                 if let Some(result) = handle_event(&event, &mut current_tool_call) {
                                     yield result;
                                 }
@@ -184,14 +211,16 @@ impl StreamingCompletionModel for CompletionModel {
                     }
                 }
             }
-        }))
+        });
+
+        Ok(streaming::StreamingCompletionResponse::new(stream))
     }
 }
 
 fn handle_event(
     event: &StreamingEvent,
     current_tool_call: &mut Option<ToolCallState>,
-) -> Option<Result<RawStreamingChoice, CompletionError>> {
+) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
     match event {
         StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
             ContentDelta::TextDelta { text } => {
diff --git a/rig-core/src/providers/azure.rs b/rig-core/src/providers/azure.rs
index 4c3b8b3..663464f 100644
--- a/rig-core/src/providers/azure.rs
+++ b/rig-core/src/providers/azure.rs
@@ -12,7 +12,7 @@ use super::openai::{send_compatible_streaming_request, TranscriptionResponse};
 
 use crate::json_utils::merge;
-use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -570,11 +570,11 @@ impl completion::CompletionModel for CompletionModel {
 // Azure OpenAI Streaming API
 // -----------------------------------------------------
 impl StreamingCompletionModel for CompletionModel {
-    type Response = openai::StreamingCompletionResponse;
+    type StreamingResponse = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         request: CompletionRequest,
-    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(request)?;
 
         request = merge(
diff --git a/rig-core/src/providers/deepseek.rs b/rig-core/src/providers/deepseek.rs
index 71a058b..b6302a0 100644
--- a/rig-core/src/providers/deepseek.rs
+++ b/rig-core/src/providers/deepseek.rs
@@ -12,7 +12,7 @@ use crate::json_utils::merge;
 
 use crate::providers::openai;
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     completion::{self, CompletionError, CompletionModel, CompletionRequest},
     extractor::ExtractorBuilder,
@@ -464,11 +464,11 @@ impl CompletionModel for DeepSeekCompletionModel {
 }
 
 impl StreamingCompletionModel for DeepSeekCompletionModel {
-    type Response = openai::StreamingCompletionResponse;
+    type StreamingResponse = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;
 
         request = merge(
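Two integration styles appear in this patch. Providers with native stream formats (anthropic above; gemini and ollama below) declare their own `StreamingResponse` payload and yield the final item themselves, while the OpenAI-compatible providers (azure and deepseek above, plus most of the rest) set `type StreamingResponse = openai::StreamingCompletionResponse` and inherit the behaviour from `send_compatible_streaming_request`. The native-provider contract reduces to something like the following sketch; `final_usage` is a hypothetical helper, and the field types are assumed from the anthropic code above, where `input_tokens` is captured at `MessageStart` and `output_tokens` arrives with the closing `MessageDelta`:

    // How a native provider packages usage for the generic layer: one
    // FinalResponse as the last meaningful item in its stream.
    fn final_usage(input: usize, output: usize) -> RawStreamingChoice<StreamingCompletionResponse> {
        RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
            usage: PartialUsage {
                output_tokens: output,
                input_tokens: Some(input),
            },
        })
    }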
diff --git a/rig-core/src/providers/galadriel.rs b/rig-core/src/providers/galadriel.rs
index 84d9bf8..77341d5 100644
--- a/rig-core/src/providers/galadriel.rs
+++ b/rig-core/src/providers/galadriel.rs
@@ -13,7 +13,7 @@ use super::openai;
 
 use crate::json_utils::merge;
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -495,12 +495,12 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
-    type Response = openai::StreamingCompletionResponse;
+    type StreamingResponse = openai::StreamingCompletionResponse;
 
     async fn stream(
         &self,
         request: CompletionRequest,
-    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(request)?;
 
         request = merge(
diff --git a/rig-core/src/providers/gemini/completion.rs b/rig-core/src/providers/gemini/completion.rs
index d9d35c2..ccfdeff 100644
--- a/rig-core/src/providers/gemini/completion.rs
+++ b/rig-core/src/providers/gemini/completion.rs
@@ -609,7 +609,7 @@ pub mod gemini_api_types {
         HarmCategoryCivicIntegrity,
     }
 
-    #[derive(Debug, Deserialize)]
+    #[derive(Debug, Deserialize, Clone)]
     #[serde(rename_all = "camelCase")]
     pub struct UsageMetadata {
         pub prompt_token_count: i32,
diff --git a/rig-core/src/providers/gemini/streaming.rs b/rig-core/src/providers/gemini/streaming.rs
index b3240bc..8e004af 100644
--- a/rig-core/src/providers/gemini/streaming.rs
+++ b/rig-core/src/providers/gemini/streaming.rs
@@ -2,26 +2,34 @@ use async_stream::stream;
 use futures::StreamExt;
 use serde::Deserialize;
 
+use super::completion::{create_request_body, gemini_api_types::ContentCandidate, CompletionModel};
+use crate::providers::gemini::completion::gemini_api_types::UsageMetadata;
 use crate::{
     completion::{CompletionError, CompletionRequest},
-    streaming::{self, StreamingCompletionModel, StreamingResult},
+    streaming::{self, StreamingCompletionModel},
 };
 
-use super::completion::{create_request_body, gemini_api_types::ContentCandidate, CompletionModel};
-
 #[derive(Debug, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct StreamGenerateContentResponse {
     /// Candidate responses from the model.
     pub candidates: Vec<ContentCandidate>,
     pub model_version: Option<String>,
+    pub usage_metadata: UsageMetadata,
+}
+
+#[derive(Clone)]
+pub struct StreamingCompletionResponse {
+    pub usage_metadata: UsageMetadata,
 }
 
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
+    {
         let request = create_request_body(completion_request)?;
 
         let response = self
@@ -42,7 +50,7 @@ impl StreamingCompletionModel for CompletionModel {
             )));
         }
 
-        Ok(Box::pin(stream! {
+        let stream = Box::pin(stream! {
             let mut stream = response.bytes_stream();
 
             while let Some(chunk_result) = stream.next().await {
@@ -79,8 +87,16 @@ impl StreamingCompletionModel for CompletionModel {
                                 => yield Ok(streaming::RawStreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
                             _ => panic!("Unsupported response type with streaming.")
                         };
+
+                        if choice.finish_reason.is_some() {
+                            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
+                                usage_metadata: data.usage_metadata,
+                            }))
+                        }
                     }
                 }
-        }))
+        });
+
+        Ok(streaming::StreamingCompletionResponse::new(stream))
     }
 }
diff --git a/rig-core/src/providers/groq.rs b/rig-core/src/providers/groq.rs
index 08ff2ee..ffdf29a 100644
--- a/rig-core/src/providers/groq.rs
+++ b/rig-core/src/providers/groq.rs
@@ -11,7 +11,7 @@ use super::openai::{send_compatible_streaming_request, CompletionResponse, TranscriptionResponse};
 
 use crate::json_utils::merge;
 use crate::providers::openai;
-use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -364,11 +364,11 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
-    type Response = openai::StreamingCompletionResponse;
+    type StreamingResponse = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         request: CompletionRequest,
-    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(request)?;
 
         request = merge(
diff --git a/rig-core/src/providers/huggingface/streaming.rs b/rig-core/src/providers/huggingface/streaming.rs
index 8eea985..90ddd39 100644
--- a/rig-core/src/providers/huggingface/streaming.rs
+++ b/rig-core/src/providers/huggingface/streaming.rs
@@ -1,9 +1,9 @@
 use super::completion::CompletionModel;
 use crate::completion::{CompletionError, CompletionRequest};
-use crate::json_utils;
 use crate::json_utils::merge_inplace;
-use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::providers::openai::{send_compatible_streaming_request, StreamingCompletionResponse};
+use crate::streaming::StreamingCompletionModel;
+use crate::{json_utils, streaming};
 use serde::{Deserialize, Serialize};
 use serde_json::{json, Value};
 use std::convert::Infallible;
@@ -55,14 +55,19 @@ struct CompletionChunk {
 }
 
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
+    {
         let mut request = self.create_request_body(&completion_request)?;
 
         // Enable streaming
-        merge_inplace(&mut request, json!({"stream": true}));
+        merge_inplace(
+            &mut request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );
 
         if let Some(ref params) = completion_request.additional_params {
             merge_inplace(&mut request, params.clone());
diff --git a/rig-core/src/providers/hyperbolic.rs b/rig-core/src/providers/hyperbolic.rs
index 00098e6..e0a69e2 100644
--- a/rig-core/src/providers/hyperbolic.rs
+++ b/rig-core/src/providers/hyperbolic.rs
@@ -12,7 +12,7 @@ use super::openai::{send_compatible_streaming_request, AssistantContent};
 
 use crate::json_utils::merge_inplace;
-use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -390,11 +390,11 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
-    type Response = openai::StreamingCompletionResponse;
+    type StreamingResponse = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;
 
         merge_inplace(
diff --git a/rig-core/src/providers/mira.rs b/rig-core/src/providers/mira.rs
index 1f43eb6..16687bf 100644
--- a/rig-core/src/providers/mira.rs
+++ b/rig-core/src/providers/mira.rs
@@ -8,8 +8,9 @@
 //!
 //! ```
 use crate::json_utils::merge;
+use crate::providers::openai;
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -24,7 +25,6 @@ use serde_json::{json, Value};
 use std::string::FromUtf8Error;
 use thiserror::Error;
 use tracing;
-use crate::providers::openai;
 
 #[derive(Debug, Error)]
 pub enum MiraError {
@@ -348,11 +348,11 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
-    type Response = openai::StreamingCompletionResponse;
+    type StreamingResponse = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;
 
         request = merge(request, json!({"stream": true}));
diff --git a/rig-core/src/providers/moonshot.rs b/rig-core/src/providers/moonshot.rs
index 2906a89..9023133 100644
--- a/rig-core/src/providers/moonshot.rs
+++ b/rig-core/src/providers/moonshot.rs
@@ -11,7 +11,7 @@ use crate::json_utils::merge;
 
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -228,12 +228,12 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
-    type Response = openai::StreamingCompletionResponse;
+    type StreamingResponse = openai::StreamingCompletionResponse;
 
     async fn stream(
         &self,
         request: CompletionRequest,
-    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(request)?;
 
         request = merge(
diff --git a/rig-core/src/providers/ollama.rs b/rig-core/src/providers/ollama.rs
index 1429ad0..4c93475 100644
--- a/rig-core/src/providers/ollama.rs
+++ b/rig-core/src/providers/ollama.rs
@@ -39,7 +39,7 @@
 //! let extractor = client.extractor::<serde_json::Value>("llama3.2");
 //! ```
 use crate::json_utils::merge_inplace;
-use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult};
+use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -405,30 +405,25 @@ impl completion::CompletionModel for CompletionModel {
     }
 }
 
+#[derive(Clone)]
 pub struct StreamingCompletionResponse {
-    #[serde(default)]
     pub done_reason: Option<String>,
-    #[serde(default)]
     pub total_duration: Option<u64>,
-    #[serde(default)]
     pub load_duration: Option<u64>,
-    #[serde(default)]
     pub prompt_eval_count: Option<u64>,
-    #[serde(default)]
     pub prompt_eval_duration: Option<u64>,
-    #[serde(default)]
     pub eval_count: Option<u64>,
-    #[serde(default)]
     pub eval_duration: Option<u64>,
 }
 
 impl StreamingCompletionModel for CompletionModel {
-    type Response = StreamingCompletionResponse;
+    type StreamingResponse = StreamingCompletionResponse;
     async fn stream(
         &self,
         request: CompletionRequest,
-    ) -> Result<streaming::StreamingCompletionResponse<Self::Response>, CompletionError> {
+    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
+    {
         let mut request_payload = self.create_completion_request(request)?;
         merge_inplace(&mut request_payload, json!({"stream": true}));
@@ -448,7 +443,6 @@ impl StreamingCompletionModel for CompletionModel {
             return Err(CompletionError::ProviderError(err_text));
         }
 
-        let mut
         let stream = Box::pin(stream! {
             let mut stream = response.bytes_stream();
             while let Some(chunk_result) = stream.next().await {
@@ -495,6 +489,8 @@ impl StreamingCompletionModel for CompletionModel {
                     }
                 }
             }
         });
+
+        Ok(streaming::StreamingCompletionResponse::new(stream))
     }
 }
diff --git a/rig-core/src/providers/openai/streaming.rs b/rig-core/src/providers/openai/streaming.rs
index af68d58..015bc8a 100644
--- a/rig-core/src/providers/openai/streaming.rs
+++ b/rig-core/src/providers/openai/streaming.rs
@@ -4,14 +4,13 @@ use crate::json_utils;
 use crate::json_utils::merge;
 use crate::providers::openai::Usage;
 use crate::streaming;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
 use async_stream::stream;
 use futures::StreamExt;
 use reqwest::RequestBuilder;
 use serde::{Deserialize, Serialize};
 use serde_json::json;
 use std::collections::HashMap;
-use tokio::stream;
 
 // ================================================================
 // OpenAI Completion Streaming API
@@ -47,18 +46,21 @@ struct StreamingChoice {
 struct StreamingCompletionChunk {
     choices: Vec<StreamingChoice>,
     usage: Option<Usage>,
+    finish_reason: Option<String>,
 }
 
+#[derive(Clone)]
 pub struct StreamingCompletionResponse {
-    usage: Option<Usage>,
+    pub usage: Option<Usage>,
 }
 
 impl StreamingCompletionModel for CompletionModel {
-    type Response = StreamingCompletionResponse;
+    type StreamingResponse = StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<streaming::StreamingCompletionResponse<Self::Response>, CompletionError> {
+    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
+    {
         let mut request = self.create_completion_request(completion_request)?;
 
         request = merge(request, json!({"stream": true}));
@@ -179,9 +181,12 @@ pub async fn send_compatible_streaming_request(
                         yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
                     }
 
-                    if &data.usage.is_some() {
-                        usage = data.usage;
+                    if data.finish_reason.is_some() {
+                        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
+                            usage: data.usage
+                        }))
                     }
+
                 }
             }
@@ -194,7 +199,5 @@ pub async fn send_compatible_streaming_request(
         }
     });
 
-    Ok(streaming::StreamingCompletionResponse::new(
-        inner,
-    ))
+    Ok(streaming::StreamingCompletionResponse::new(inner))
 }
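A caveat on the compatible path above: `usage` is only present if the API emits a usage chunk at all. OpenAI-style endpoints do that when `stream_options.include_usage` is set, which this diff enables for huggingface but not for the other compatible providers, so `usage` can legitimately remain `None`. Reading it once the stream is drained (sketch; applies to every provider reusing `openai::StreamingCompletionResponse`):

    // After the stream ends, usage (if the API sent it) sits on the final
    // response; it is an Option, so absence is data, not an error.
    if let Some(final_response) = stream.response {
        match final_response.usage {
            Some(usage) => println!("Usage: {usage:?}"),
            None => println!("Provider reported no usage"),
        }
    }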
diff --git a/rig-core/src/providers/perplexity.rs b/rig-core/src/providers/perplexity.rs
index 1b3d1ea..c09c4d6 100644
--- a/rig-core/src/providers/perplexity.rs
+++ b/rig-core/src/providers/perplexity.rs
@@ -18,8 +18,9 @@ use crate::{
 
 use crate::completion::CompletionRequest;
 use crate::json_utils::merge;
+use crate::providers::openai;
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use serde_json::{json, Value};
@@ -345,10 +346,11 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: completion::CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;
 
         request = merge(request, json!({"stream": true}));
diff --git a/rig-core/src/providers/together/streaming.rs b/rig-core/src/providers/together/streaming.rs
index dffb3b8..f437828 100644
--- a/rig-core/src/providers/together/streaming.rs
+++ b/rig-core/src/providers/together/streaming.rs
@@ -1,18 +1,21 @@
 use serde_json::json;
 
 use super::completion::CompletionModel;
+use crate::providers::openai;
 use crate::providers::openai::send_compatible_streaming_request;
+use crate::streaming::StreamingCompletionResponse;
 use crate::{
     completion::{CompletionError, CompletionRequest},
     json_utils::merge,
-    streaming::{StreamingCompletionModel, StreamingResult},
+    streaming::StreamingCompletionModel,
 };
 
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;
 
         request = merge(request, json!({"stream_tokens": true}));
diff --git a/rig-core/src/providers/xai/streaming.rs b/rig-core/src/providers/xai/streaming.rs
index 80ab430..b87a53e 100644
--- a/rig-core/src/providers/xai/streaming.rs
+++ b/rig-core/src/providers/xai/streaming.rs
@@ -1,15 +1,17 @@
 use crate::completion::{CompletionError, CompletionRequest};
 use crate::json_utils::merge;
+use crate::providers::openai;
 use crate::providers::openai::send_compatible_streaming_request;
 use crate::providers::xai::completion::CompletionModel;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use serde_json::json;
 
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;
 
         request = merge(request, json!({"stream": true}));
diff --git a/rig-core/src/streaming.rs b/rig-core/src/streaming.rs
index 62e7ec2..4d61116 100644
--- a/rig-core/src/streaming.rs
+++ b/rig-core/src/streaming.rs
@@ -14,6 +14,7 @@ use crate::completion::{
     CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, Message,
 };
 use crate::message::AssistantContent;
+use crate::OneOrMany;
 use futures::{Stream, StreamExt};
 use std::boxed::Box;
 use std::fmt::{Display, Formatter};
@@ -23,7 +24,7 @@
 use std::task::{Context, Poll};
 
 /// Enum representing a streaming chunk from the model
 #[derive(Debug, Clone)]
-pub enum RawStreamingChoice<R> {
+pub enum RawStreamingChoice<R: Clone> {
     /// A text chunk from a message response
     Message(String),
@@ -35,6 +36,7 @@ pub enum RawStreamingChoice<R> {
 }
 
 /// Enum representing a streaming chunk from the model
+#[derive(Debug, Clone)]
 pub enum StreamingChoice {
     /// A text chunk from a message response
     Message(String),
@@ -61,7 +63,7 @@ pub type StreamingResult<R> =
 #[cfg(target_arch = "wasm32")]
 pub type StreamingResult<R> = Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
 
-pub struct StreamingCompletionResponse<R> {
+pub struct StreamingCompletionResponse<R: Clone + Unpin> {
     inner: StreamingResult<R>,
     text: String,
     tool_calls: Vec<(String, String, serde_json::Value)>,
@@ -69,7 +71,7 @@ pub struct StreamingCompletionResponse<R> {
     pub response: Option<R>,
 }
 
-impl<R> StreamingCompletionResponse<R> {
+impl<R: Clone + Unpin> StreamingCompletionResponse<R> {
     pub fn new(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
         Self {
             inner,
@@ -81,37 +83,45 @@ impl<R> StreamingCompletionResponse<R> {
     }
 }
 
-impl<R> Stream for StreamingCompletionResponse<R> {
+impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
     type Item = Result<StreamingChoice, CompletionError>;
 
     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        match self.inner.poll_next(cx) {
+        let stream = self.get_mut();
+
+        match stream.inner.as_mut().poll_next(cx) {
             Poll::Pending => Poll::Pending,
             Poll::Ready(None) => {
-                let content = vec![AssistantContent::text(self.text.clone())];
+                let mut content = vec![AssistantContent::text(stream.text.clone())];
 
-                self.tool_calls
-                    .iter()
-                    .for_each(|(n, d, a)| AssistantContent::tool_call(n, derive!(), a));
+                // Collect the accumulated tool calls into the final message
+                // (previously they were constructed and then dropped).
+                stream.tool_calls.iter().for_each(|(n, d, a)| {
+                    content.push(AssistantContent::tool_call(n, d, a.clone()));
+                });
 
-                self.message = Message::Assistant {
-                    content: content.into(),
-                }
+                stream.message = Message::Assistant {
+                    content: OneOrMany::many(content)
+                        .expect("There should be at least one assistant message"),
+                };
+
+                Poll::Ready(None)
             }
             Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
             Poll::Ready(Some(Ok(choice))) => match choice {
                 RawStreamingChoice::Message(text) => {
-                    self.text = format!("{}{}", self.text, text);
-                    Poll::Ready(Some(Ok(choice.clone())))
+                    stream.text = format!("{}{}", stream.text, text.clone());
+                    Poll::Ready(Some(Ok(StreamingChoice::Message(text))))
                 }
                 RawStreamingChoice::ToolCall(name, description, args) => {
-                    self.tool_calls
-                        .push((name, description, args));
-                    Poll::Ready(Some(Ok(choice.clone())))
+                    stream
+                        .tool_calls
+                        .push((name.clone(), description.clone(), args.clone()));
+                    Poll::Ready(Some(Ok(StreamingChoice::ToolCall(name, description, args))))
                 }
                 RawStreamingChoice::FinalResponse(response) => {
-                    self.response = Some(response);
-                    Poll::Pending
+                    stream.response = Some(response);
+
+                    // Poll again instead of returning Poll::Pending: no waker
+                    // has been registered here, so Pending would stall the stream.
+                    stream.poll_next_unpin(cx)
                 }
             },
@@ -120,7 +128,7 @@ impl<R> Stream for StreamingCompletionResponse<R> {
 }
 
 /// Trait for high-level streaming prompt interface
-pub trait StreamingPrompt<R>: Send + Sync {
+pub trait StreamingPrompt<R: Clone + Unpin>: Send + Sync {
     /// Stream a simple prompt to the model
     fn stream_prompt(
         &self,
@@ -129,7 +137,7 @@ pub trait StreamingPrompt<R>: Send + Sync {
 }
 
 /// Trait for high-level streaming chat interface
-pub trait StreamingChat<R>: Send + Sync {
+pub trait StreamingChat<R: Clone + Unpin>: Send + Sync {
     /// Stream a chat with history to the model
     fn stream_chat(
         &self,
@@ -150,18 +158,20 @@ pub trait StreamingCompletion<M: StreamingCompletionModel> {
 }
 
 /// Trait defining a streaming completion model
 pub trait StreamingCompletionModel: CompletionModel {
-    type Response;
+    type StreamingResponse: Clone + Unpin;
     /// Stream a completion response for the given request
     fn stream(
         &self,
         request: CompletionRequest,
-    ) -> impl Future<Output = Result<StreamingCompletionResponse<Self::Response>, CompletionError>>;
+    ) -> impl Future<
+        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
+    >;
 }
 
 /// helper function to stream a completion request to stdout
 pub async fn stream_to_stdout<M: StreamingCompletionModel>(
     agent: Agent<M>,
-    stream: &mut StreamingCompletionResponse<M::Response>,
+    stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
 ) -> Result<(), std::io::Error> {
     print!("Response: ");
     while let Some(chunk) = stream.next().await {