From fcbe648f779f773c5ada399792b7452e35d00ba6 Mon Sep 17 00:00:00 2001 From: yavens <179155341+yavens@users.github.noreply.github.com> Date: Tue, 8 Apr 2025 23:43:50 -0400 Subject: [PATCH 01/16] feat: start refactoring streaming api --- rig-core/src/providers/anthropic/streaming.rs | 8 +- rig-core/src/providers/azure.rs | 13 ++- rig-core/src/providers/deepseek.rs | 11 ++- rig-core/src/providers/galadriel.rs | 14 ++- rig-core/src/providers/gemini/streaming.rs | 4 +- rig-core/src/providers/groq.rs | 14 ++- rig-core/src/providers/hyperbolic.rs | 12 ++- rig-core/src/providers/mira.rs | 6 +- rig-core/src/providers/moonshot.rs | 14 ++- rig-core/src/providers/ollama.rs | 37 +++++-- rig-core/src/providers/openai/streaming.rs | 34 +++++-- rig-core/src/streaming.rs | 96 ++++++++++++++++--- 12 files changed, 210 insertions(+), 53 deletions(-) diff --git a/rig-core/src/providers/anthropic/streaming.rs b/rig-core/src/providers/anthropic/streaming.rs index b351515..85159f4 100644 --- a/rig-core/src/providers/anthropic/streaming.rs +++ b/rig-core/src/providers/anthropic/streaming.rs @@ -8,7 +8,7 @@ use super::decoders::sse::from_response as sse_from_response; use crate::completion::{CompletionError, CompletionRequest}; use crate::json_utils::merge_inplace; use crate::message::MessageError; -use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult}; +use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult}; #[derive(Debug, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] @@ -191,12 +191,12 @@ impl StreamingCompletionModel for CompletionModel { fn handle_event( event: &StreamingEvent, current_tool_call: &mut Option, -) -> Option> { +) -> Option> { match event { StreamingEvent::ContentBlockDelta { delta, .. 
} => match delta { ContentDelta::TextDelta { text } => { if current_tool_call.is_none() { - return Some(Ok(StreamingChoice::Message(text.clone()))); + return Some(Ok(RawStreamingChoice::Message(text.clone()))); } None } @@ -227,7 +227,7 @@ fn handle_event( &tool_call.input_json }; match serde_json::from_str(json_str) { - Ok(json_value) => Some(Ok(StreamingChoice::ToolCall( + Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall( tool_call.name, tool_call.id, json_value, diff --git a/rig-core/src/providers/azure.rs b/rig-core/src/providers/azure.rs index c2e0ec9..4c3b8b3 100644 --- a/rig-core/src/providers/azure.rs +++ b/rig-core/src/providers/azure.rs @@ -12,7 +12,7 @@ use super::openai::{send_compatible_streaming_request, TranscriptionResponse}; use crate::json_utils::merge; -use crate::streaming::{StreamingCompletionModel, StreamingResult}; +use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult}; use crate::{ agent::AgentBuilder, completion::{self, CompletionError, CompletionRequest}, @@ -570,10 +570,17 @@ impl completion::CompletionModel for CompletionModel { // Azure OpenAI Streaming API // ----------------------------------------------------- impl StreamingCompletionModel for CompletionModel { - async fn stream(&self, request: CompletionRequest) -> Result { + type Response = openai::StreamingCompletionResponse; + async fn stream( + &self, + request: CompletionRequest, + ) -> Result, CompletionError> { let mut request = self.create_completion_request(request)?; - request = merge(request, json!({"stream": true})); + request = merge( + request, + json!({"stream": true, "stream_options": {"include_usage": true}}), + ); let builder = self .client diff --git a/rig-core/src/providers/deepseek.rs b/rig-core/src/providers/deepseek.rs index 86d9a0f..71a058b 100644 --- a/rig-core/src/providers/deepseek.rs +++ b/rig-core/src/providers/deepseek.rs @@ -10,8 +10,9 @@ //! 
``` use crate::json_utils::merge; +use crate::providers::openai; use crate::providers::openai::send_compatible_streaming_request; -use crate::streaming::{StreamingCompletionModel, StreamingResult}; +use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult}; use crate::{ completion::{self, CompletionError, CompletionModel, CompletionRequest}, extractor::ExtractorBuilder, @@ -463,13 +464,17 @@ impl CompletionModel for DeepSeekCompletionModel { } impl StreamingCompletionModel for DeepSeekCompletionModel { + type Response = openai::StreamingCompletionResponse; async fn stream( &self, completion_request: CompletionRequest, - ) -> Result { + ) -> Result, CompletionError> { let mut request = self.create_completion_request(completion_request)?; - request = merge(request, json!({"stream": true})); + request = merge( + request, + json!({"stream": true, "stream_options": {"include_usage": true}}), + ); let builder = self.client.post("/v1/chat/completions").json(&request); send_compatible_streaming_request(builder).await diff --git a/rig-core/src/providers/galadriel.rs b/rig-core/src/providers/galadriel.rs index 3b536c9..84d9bf8 100644 --- a/rig-core/src/providers/galadriel.rs +++ b/rig-core/src/providers/galadriel.rs @@ -13,7 +13,7 @@ use super::openai; use crate::json_utils::merge; use crate::providers::openai::send_compatible_streaming_request; -use crate::streaming::{StreamingCompletionModel, StreamingResult}; +use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult}; use crate::{ agent::AgentBuilder, completion::{self, CompletionError, CompletionRequest}, @@ -495,10 +495,18 @@ impl completion::CompletionModel for CompletionModel { } impl StreamingCompletionModel for CompletionModel { - async fn stream(&self, request: CompletionRequest) -> Result { + type Response = openai::StreamingCompletionResponse; + + async fn stream( + &self, + request: CompletionRequest, + ) -> Result, CompletionError> { let mut request = self.create_completion_request(request)?; - request = merge(request, json!({"stream": true})); + request = merge( + request, + json!({"stream": true, "stream_options": {"include_usage": true}}), + ); let builder = self.client.post("/chat/completions").json(&request); diff --git a/rig-core/src/providers/gemini/streaming.rs b/rig-core/src/providers/gemini/streaming.rs index 362d9ae..b3240bc 100644 --- a/rig-core/src/providers/gemini/streaming.rs +++ b/rig-core/src/providers/gemini/streaming.rs @@ -74,9 +74,9 @@ impl StreamingCompletionModel for CompletionModel { match choice.content.parts.first() { super::completion::gemini_api_types::Part::Text(text) - => yield Ok(streaming::StreamingChoice::Message(text)), + => yield Ok(streaming::RawStreamingChoice::Message(text)), super::completion::gemini_api_types::Part::FunctionCall(function_call) - => yield Ok(streaming::StreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)), + => yield Ok(streaming::RawStreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)), _ => panic!("Unsupported response type with streaming.") }; } diff --git a/rig-core/src/providers/groq.rs b/rig-core/src/providers/groq.rs index f852e77..08ff2ee 100644 --- a/rig-core/src/providers/groq.rs +++ b/rig-core/src/providers/groq.rs @@ -10,7 +10,8 @@ //! 
``` use super::openai::{send_compatible_streaming_request, CompletionResponse, TranscriptionResponse}; use crate::json_utils::merge; -use crate::streaming::{StreamingCompletionModel, StreamingResult}; +use crate::providers::openai; +use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult}; use crate::{ agent::AgentBuilder, completion::{self, CompletionError, CompletionRequest}, @@ -363,10 +364,17 @@ impl completion::CompletionModel for CompletionModel { } impl StreamingCompletionModel for CompletionModel { - async fn stream(&self, request: CompletionRequest) -> Result { + type Response = openai::StreamingCompletionResponse; + async fn stream( + &self, + request: CompletionRequest, + ) -> Result, CompletionError> { let mut request = self.create_completion_request(request)?; - request = merge(request, json!({"stream": true})); + request = merge( + request, + json!({"stream": true, "stream_options": {"include_usage": true}}), + ); let builder = self.client.post("/chat/completions").json(&request); diff --git a/rig-core/src/providers/hyperbolic.rs b/rig-core/src/providers/hyperbolic.rs index 1606635..00098e6 100644 --- a/rig-core/src/providers/hyperbolic.rs +++ b/rig-core/src/providers/hyperbolic.rs @@ -12,7 +12,7 @@ use super::openai::{send_compatible_streaming_request, AssistantContent}; use crate::json_utils::merge_inplace; -use crate::streaming::{StreamingCompletionModel, StreamingResult}; +use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult}; use crate::{ agent::AgentBuilder, completion::{self, CompletionError, CompletionRequest}, @@ -390,13 +390,17 @@ impl completion::CompletionModel for CompletionModel { } impl StreamingCompletionModel for CompletionModel { + type Response = openai::StreamingCompletionResponse; async fn stream( &self, completion_request: CompletionRequest, - ) -> Result { + ) -> Result, CompletionError> { let mut request = self.create_completion_request(completion_request)?; - merge_inplace(&mut request, json!({"stream": true})); + merge_inplace( + &mut request, + json!({"stream": true, "stream_options": {"include_usage": true}}), + ); let builder = self.client.post("/chat/completions").json(&request); @@ -526,8 +530,10 @@ mod image_generation { // ====================================== // Hyperbolic Audio Generation API // ====================================== +use crate::providers::openai; #[cfg(feature = "audio")] pub use audio_generation::*; + #[cfg(feature = "audio")] mod audio_generation { use super::{ApiResponse, Client}; diff --git a/rig-core/src/providers/mira.rs b/rig-core/src/providers/mira.rs index 1d11c19..1f43eb6 100644 --- a/rig-core/src/providers/mira.rs +++ b/rig-core/src/providers/mira.rs @@ -9,7 +9,7 @@ //! 
``` use crate::json_utils::merge; use crate::providers::openai::send_compatible_streaming_request; -use crate::streaming::{StreamingCompletionModel, StreamingResult}; +use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult}; use crate::{ agent::AgentBuilder, completion::{self, CompletionError, CompletionRequest}, @@ -24,6 +24,7 @@ use serde_json::{json, Value}; use std::string::FromUtf8Error; use thiserror::Error; use tracing; +use crate::providers::openai; #[derive(Debug, Error)] pub enum MiraError { @@ -347,10 +348,11 @@ impl completion::CompletionModel for CompletionModel { } impl StreamingCompletionModel for CompletionModel { + type Response = openai::StreamingCompletionResponse; async fn stream( &self, completion_request: CompletionRequest, - ) -> Result { + ) -> Result, CompletionError> { let mut request = self.create_completion_request(completion_request)?; request = merge(request, json!({"stream": true})); diff --git a/rig-core/src/providers/moonshot.rs b/rig-core/src/providers/moonshot.rs index 278863a..2906a89 100644 --- a/rig-core/src/providers/moonshot.rs +++ b/rig-core/src/providers/moonshot.rs @@ -11,7 +11,7 @@ use crate::json_utils::merge; use crate::providers::openai::send_compatible_streaming_request; -use crate::streaming::{StreamingCompletionModel, StreamingResult}; +use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult}; use crate::{ agent::AgentBuilder, completion::{self, CompletionError, CompletionRequest}, @@ -228,10 +228,18 @@ impl completion::CompletionModel for CompletionModel { } impl StreamingCompletionModel for CompletionModel { - async fn stream(&self, request: CompletionRequest) -> Result { + type Response = openai::StreamingCompletionResponse; + + async fn stream( + &self, + request: CompletionRequest, + ) -> Result, CompletionError> { let mut request = self.create_completion_request(request)?; - request = merge(request, json!({"stream": true})); + request = merge( + request, + json!({"stream": true, "stream_options": {"include_usage": true}}), + ); let builder = self.client.post("/chat/completions").json(&request); diff --git a/rig-core/src/providers/ollama.rs b/rig-core/src/providers/ollama.rs index 4b539b8..1429ad0 100644 --- a/rig-core/src/providers/ollama.rs +++ b/rig-core/src/providers/ollama.rs @@ -39,7 +39,7 @@ //! let extractor = client.extractor::("llama3.2"); //! 
```
use crate::json_utils::merge_inplace;
-use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
+use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult};
use crate::{
    agent::AgentBuilder,
    completion::{self, CompletionError, CompletionRequest},
@@ -47,7 +47,7 @@ use crate::{
    extractor::ExtractorBuilder,
    json_utils, message,
    message::{ImageDetail, Text},
-    Embed, OneOrMany,
+    streaming, Embed, OneOrMany,
};
use async_stream::stream;
use futures::StreamExt;
@@ -405,8 +405,30 @@ impl completion::CompletionModel for CompletionModel {
    }
}

+pub struct StreamingCompletionResponse {
+    #[serde(default)]
+    pub done_reason: Option<String>,
+    #[serde(default)]
+    pub total_duration: Option<u64>,
+    #[serde(default)]
+    pub load_duration: Option<u64>,
+    #[serde(default)]
+    pub prompt_eval_count: Option<u64>,
+    #[serde(default)]
+    pub prompt_eval_duration: Option<u64>,
+    #[serde(default)]
+    pub eval_count: Option<u64>,
+    #[serde(default)]
+    pub eval_duration: Option<u64>,
+}
+
impl StreamingCompletionModel for CompletionModel {
-    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
+    type Response = StreamingCompletionResponse;
+
+    async fn stream(
+        &self,
+        request: CompletionRequest,
+    ) -> Result<streaming::StreamingCompletionResponse<Self::Response>, CompletionError> {
        let mut request_payload = self.create_completion_request(request)?;
        merge_inplace(&mut request_payload, json!({"stream": true}));

@@ -426,7 +448,8 @@ impl StreamingCompletionModel for CompletionModel {
            return Err(CompletionError::ProviderError(err_text));
        }

-        Ok(Box::pin(stream! {
+        let mut
+        let stream = Box::pin(stream! {
            let mut stream = response.bytes_stream();
            while let Some(chunk_result) = stream.next().await {
                let chunk = match chunk_result {
@@ -456,13 +479,13 @@ impl StreamingCompletionModel for CompletionModel {
                match response.message {
                    Message::Assistant{ content, tool_calls, .. } => {
                        if !content.is_empty() {
-                            yield Ok(StreamingChoice::Message(content))
+                            yield Ok(RawStreamingChoice::Message(content))
                        }

                        for tool_call in tool_calls.iter() {
                            let function = tool_call.function.clone();

-                            yield Ok(StreamingChoice::ToolCall(function.name, "".to_string(), function.arguments));
+                            yield Ok(RawStreamingChoice::ToolCall(function.name, "".to_string(), function.arguments));
                        }
                    }
                    _ => {
@@ -471,7 +494,7 @@ impl StreamingCompletionModel for CompletionModel {
                    }
                }
            }
-        }))
+        });
    }
}
diff --git a/rig-core/src/providers/openai/streaming.rs
index 3727468..af68d58 100644
--- a/rig-core/src/providers/openai/streaming.rs
+++ b/rig-core/src/providers/openai/streaming.rs
@@ -2,6 +2,7 @@ use super::completion::CompletionModel;
use crate::completion::{CompletionError, CompletionRequest};
use crate::json_utils;
use crate::json_utils::merge;
+use crate::providers::openai::Usage;
use crate::streaming;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use async_stream::stream;
@@ -10,6 +11,7 @@ use reqwest::RequestBuilder;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
+use tokio::stream;

// ================================================================
// OpenAI Completion Streaming API
// ================================================================
@@ -42,15 +44,21 @@ struct StreamingChoice {
}

#[derive(Deserialize)]
-struct StreamingCompletionResponse {
+struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
+    usage: Option<Usage>,
+}
+
+pub struct StreamingCompletionResponse {
+    usage: Option<Usage>,
}

impl StreamingCompletionModel for CompletionModel {
+    type Response = StreamingCompletionResponse;
    async fn stream(
        &self,
        completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<streaming::StreamingCompletionResponse<Self::Response>, CompletionError> {
        let mut request = self.create_completion_request(completion_request)?;

        request = merge(request, json!({"stream": true}));
@@ -61,7 +69,7 @@ impl StreamingCompletionModel for CompletionModel {

pub async fn send_compatible_streaming_request(
    request_builder: RequestBuilder,
-) -> Result<StreamingResult, CompletionError> {
+) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError> {
    let response = request_builder.send().await?;

    if !response.status().is_success() {
@@ -73,7 +81,7 @@ pub async fn send_compatible_streaming_request(
    }

    // Handle OpenAI Compatible SSE chunks
-    Ok(Box::pin(stream! {
+    let inner = Box::pin(stream! {
            let mut stream = response.bytes_stream();

            let mut partial_data = None;
            let mut calls: HashMap<usize, (String, String)> = HashMap::new();

@@ -121,7 +129,7 @@ pub async fn send_compatible_streaming_request(
                    }
                }

-                let data = serde_json::from_str::<StreamingCompletionResponse>(&line);
+                let data = serde_json::from_str::<StreamingCompletionChunk>(&line);

                let Ok(data) = data else {
                    continue;
                };
@@ -162,13 +170,17 @@ pub async fn send_compatible_streaming_request(
                                continue;
                            };

-                            yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
+                            yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments))
                        }
                    }
                }

                if let Some(content) = &choice.delta.content {
-                    yield Ok(streaming::StreamingChoice::Message(content.clone()))
+                    yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
+                }
+
+                if &data.usage.is_some() {
+                    usage = data.usage;
                }
            }
        }
@@ -178,7 +190,11 @@ pub async fn send_compatible_streaming_request(
            continue;
        };

-        yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
+        yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments))
    }
-    }))
+    });
+
+    Ok(streaming::StreamingCompletionResponse::new(
+        inner,
+    ))
}
diff --git a/rig-core/src/streaming.rs
index 76f91c4..62e7ec2 100644
--- a/rig-core/src/streaming.rs
+++ b/rig-core/src/streaming.rs
@@ -13,14 +13,28 @@ use crate::agent::Agent;
use crate::completion::{
    CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, Message,
};
+use crate::message::AssistantContent;
use futures::{Stream, StreamExt};
use std::boxed::Box;
use std::fmt::{Display, Formatter};
use std::future::Future;
use std::pin::Pin;
+use std::task::{Context, Poll};
+
+/// Enum representing a streaming chunk from the model
+#[derive(Debug, Clone)]
+pub enum RawStreamingChoice<R> {
+    /// A text chunk from a message response
+    Message(String),
+
+    /// A tool call response chunk
+    ToolCall(String, String, serde_json::Value),
+
+    /// The final response object
+    FinalResponse(R),
+}

/// Enum representing a streaming chunk from the model
-#[derive(Debug)]
pub enum StreamingChoice {
    /// A text chunk from a message response
    Message(String),
@@ -41,29 +55,87 @@ impl Display for StreamingChoice {
}

#[cfg(not(target_arch = "wasm32"))]
-pub type StreamingResult =
-    Pin<Box<dyn Stream<Item = Result<StreamingChoice, CompletionError>> + Send>>;
+pub type StreamingResult<R> =
+    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

#[cfg(target_arch = "wasm32")]
-pub type StreamingResult = Pin<Box<dyn Stream<Item = Result<StreamingChoice, CompletionError>>>>;
+pub type StreamingResult<R> = Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
+
+pub struct StreamingCompletionResponse<R> {
+    inner: StreamingResult<R>,
+    text: String,
+    tool_calls: Vec<(String, String, serde_json::Value)>,
+    pub message: Message,
+    pub response: Option<R>,
+}
+
+impl<R> StreamingCompletionResponse<R> {
+    pub fn new(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
+        Self {
+            inner,
+            text: "".to_string(),
+            tool_calls: vec![],
+            message: Message::assistant(""),
+            response: None,
+        }
+    }
+}
+
+impl<R> Stream for StreamingCompletionResponse<R> {
+    type Item = Result<StreamingChoice, CompletionError>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        match self.inner.poll_next(cx) {
+            Poll::Pending => Poll::Pending,
+
+            Poll::Ready(None) => {
+                let content = vec![AssistantContent::text(self.text.clone())];
+
+                self.tool_calls
+                    .iter()
+                    .for_each(|(n, d, a)| AssistantContent::tool_call(n, derive!(), a));
+
+                self.message = Message::Assistant {
+                    content: content.into(),
+                }
+            }
+            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
+            Poll::Ready(Some(Ok(choice))) => match choice {
+                RawStreamingChoice::Message(text) => {
+                    self.text = format!("{}{}", self.text, text);
+                    Poll::Ready(Some(Ok(choice.clone())))
+                }
+                RawStreamingChoice::ToolCall(name, description, args) => {
+                    self.tool_calls
+                        .push((name, description, args));
+                    Poll::Ready(Some(Ok(choice.clone())))
+                }
+                RawStreamingChoice::FinalResponse(response) => {
+                    self.response = Some(response);
+                    Poll::Pending
+                }
+            },
+        }
+    }
+}

/// Trait for high-level streaming prompt interface
-pub trait StreamingPrompt: Send + Sync {
+pub trait StreamingPrompt<R>: Send + Sync {
    /// Stream a simple prompt to the model
    fn stream_prompt(
        &self,
        prompt: &str,
-    ) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
+    ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
}

/// Trait for high-level streaming chat interface
-pub trait StreamingChat: Send + Sync {
+pub trait StreamingChat<R>: Send + Sync {
    /// Stream a chat with history to the model
    fn stream_chat(
        &self,
        prompt: &str,
        chat_history: Vec<Message>,
-    ) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
+    ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
}

/// Trait for low-level streaming completion interface
@@ -78,17 +150,18 @@ pub trait StreamingCompletion<M: StreamingCompletionModel> {

/// Trait defining a streaming completion model
pub trait StreamingCompletionModel: CompletionModel {
+    type Response;
    /// Stream a completion response for the given request
    fn stream(
        &self,
        request: CompletionRequest,
-    ) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
+    ) -> impl Future<Output = Result<StreamingCompletionResponse<Self::Response>, CompletionError>>;
}

/// helper function to stream a completion request to stdout
pub async fn stream_to_stdout<M: StreamingCompletionModel>(
    agent: Agent<M>,
-    stream: &mut StreamingResult,
+    stream: &mut StreamingCompletionResponse<M::Response>,
) -> Result<(), std::io::Error> {
    print!("Response: ");
    while let Some(chunk) = stream.next().await {
@@ -111,6 +184,7 @@
            }
        }
    }
+    println!(); // New line after streaming completes

    Ok(())

From 86b84c82fb0a097c73e27fc700ca56b1bc055a83 Mon Sep 17 00:00:00 2001
From: yavens <179155341+yavens@users.github.noreply.github.com>
Date: Wed, 9 Apr 2025 12:54:55 -0400
Subject: [PATCH 02/16] fix: compiles + formatted

---
 rig-core/examples/openai_streaming.rs         |  4 ++
 rig-core/src/agent.rs                         | 22 +++----
 rig-core/src/completion/request.rs            |  6 +-
 rig-core/src/providers/anthropic/streaming.rs | 39 +++++++++++--
 rig-core/src/providers/azure.rs               |  6 +-
 rig-core/src/providers/deepseek.rs            |  6 +-
 rig-core/src/providers/galadriel.rs           |  6 +-
 rig-core/src/providers/gemini/completion.rs   |  2 +-
 rig-core/src/providers/gemini/streaming.rs    | 28 +++++++--
 rig-core/src/providers/groq.rs                |  6 +-
 .../src/providers/huggingface/streaming.rs    | 15 +++--
 rig-core/src/providers/hyperbolic.rs          |  6 +-
 rig-core/src/providers/mira.rs                |  8 +--
 rig-core/src/providers/moonshot.rs            |  6 +-
 rig-core/src/providers/ollama.rs              | 18 +++---
 rig-core/src/providers/openai/streaming.rs    | 23 ++++----
 rig-core/src/providers/perplexity.rs          |  6 +-
 rig-core/src/providers/together/streaming.rs  |  7 ++-
 rig-core/src/providers/xai/streaming.rs       |  6 +-
 rig-core/src/streaming.rs                     | 58 +++++++++++--------
 20 files changed, 175 insertions(+), 103 deletions(-)

diff --git a/rig-core/examples/openai_streaming.rs
index d4aadf0..a474702 100644
--- a/rig-core/examples/openai_streaming.rs
+++ b/rig-core/examples/openai_streaming.rs
@@ -16,6 +16,10 @@ async fn main() -> Result<(), anyhow::Error> {
        .await?;

    stream_to_stdout(agent, &mut stream).await?;
+
+    if let Some(response) = stream.response {
+        println!("Usage: {:?}", response.usage)
+    };

    Ok(())
}
diff --git a/rig-core/src/agent.rs
index 781e2a8..bf5ab0d 100644
--- a/rig-core/src/agent.rs
+++ b/rig-core/src/agent.rs
@@ -110,23 +110,20 @@ use
std::collections::HashMap; use futures::{stream, StreamExt, TryStreamExt}; +use crate::streaming::StreamingCompletionResponse; +#[cfg(feature = "mcp")] +use crate::tool::McpTool; use crate::{ completion::{ Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document, Message, Prompt, PromptError, }, message::AssistantContent, - streaming::{ - StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt, - StreamingResult, - }, + streaming::{StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt}, tool::{Tool, ToolSet}, vector_store::{VectorStoreError, VectorStoreIndexDyn}, }; -#[cfg(feature = "mcp")] -use crate::tool::McpTool; - /// Struct representing an LLM agent. An agent is an LLM model combined with a preamble /// (i.e.: system prompt) and a static set of context documents and tools. /// All context documents and tools are always provided to the agent when prompted. @@ -500,18 +497,21 @@ impl StreamingCompletion for Agent { } } -impl StreamingPrompt for Agent { - async fn stream_prompt(&self, prompt: &str) -> Result { +impl StreamingPrompt for Agent { + async fn stream_prompt( + &self, + prompt: &str, + ) -> Result, CompletionError> { self.stream_chat(prompt, vec![]).await } } -impl StreamingChat for Agent { +impl StreamingChat for Agent { async fn stream_chat( &self, prompt: &str, chat_history: Vec, - ) -> Result { + ) -> Result, CompletionError> { self.stream_completion(prompt, chat_history) .await? .stream() diff --git a/rig-core/src/completion/request.rs b/rig-core/src/completion/request.rs index 9a31fae..07e8f64 100644 --- a/rig-core/src/completion/request.rs +++ b/rig-core/src/completion/request.rs @@ -67,7 +67,7 @@ use std::collections::HashMap; use serde::{Deserialize, Serialize}; use thiserror::Error; -use crate::streaming::{StreamingCompletionModel, StreamingResult}; +use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse}; use crate::OneOrMany; use crate::{ json_utils, @@ -467,7 +467,9 @@ impl CompletionRequestBuilder { impl CompletionRequestBuilder { /// Stream the completion request - pub async fn stream(self) -> Result { + pub async fn stream( + self, + ) -> Result, CompletionError> { let model = self.model.clone(); model.stream(self.build()).await } diff --git a/rig-core/src/providers/anthropic/streaming.rs b/rig-core/src/providers/anthropic/streaming.rs index 85159f4..c343b44 100644 --- a/rig-core/src/providers/anthropic/streaming.rs +++ b/rig-core/src/providers/anthropic/streaming.rs @@ -8,6 +8,7 @@ use super::decoders::sse::from_response as sse_from_response; use crate::completion::{CompletionError, CompletionRequest}; use crate::json_utils::merge_inplace; use crate::message::MessageError; +use crate::streaming; use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult}; #[derive(Debug, Deserialize)] @@ -61,7 +62,7 @@ pub struct MessageDelta { pub stop_sequence: Option, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Clone)] pub struct PartialUsage { pub output_tokens: usize, #[serde(default)] @@ -75,11 +76,18 @@ struct ToolCallState { input_json: String, } +#[derive(Clone)] +pub struct StreamingCompletionResponse { + pub usage: PartialUsage, +} + impl StreamingCompletionModel for CompletionModel { + type StreamingResponse = StreamingCompletionResponse; async fn stream( &self, completion_request: CompletionRequest, - ) -> Result { + ) -> Result, CompletionError> + { let max_tokens = if let Some(tokens) = completion_request.max_tokens { 
tokens } else if let Some(tokens) = self.default_max_tokens { @@ -155,9 +163,10 @@ impl StreamingCompletionModel for CompletionModel { // Use our SSE decoder to directly handle Server-Sent Events format let sse_stream = sse_from_response(response); - Ok(Box::pin(stream! { + let stream: StreamingResult = Box::pin(stream! { let mut current_tool_call: Option = None; let mut sse_stream = Box::pin(sse_stream); + let mut input_tokens = 0; while let Some(sse_result) = sse_stream.next().await { match sse_result { @@ -165,6 +174,24 @@ impl StreamingCompletionModel for CompletionModel { // Parse the SSE data as a StreamingEvent match serde_json::from_str::(&sse.data) { Ok(event) => { + match &event { + StreamingEvent::MessageStart { message } => { + input_tokens = message.usage.input_tokens; + }, + StreamingEvent::MessageDelta { delta, usage } => { + if delta.stop_reason.is_some() { + + yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse { + usage: PartialUsage { + output_tokens: usage.output_tokens, + input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")), + } + })) + } + } + _ => {} + } + if let Some(result) = handle_event(&event, &mut current_tool_call) { yield result; } @@ -184,14 +211,16 @@ impl StreamingCompletionModel for CompletionModel { } } } - })) + }); + + Ok(streaming::StreamingCompletionResponse::new(stream)) } } fn handle_event( event: &StreamingEvent, current_tool_call: &mut Option, -) -> Option> { +) -> Option, CompletionError>> { match event { StreamingEvent::ContentBlockDelta { delta, .. } => match delta { ContentDelta::TextDelta { text } => { diff --git a/rig-core/src/providers/azure.rs b/rig-core/src/providers/azure.rs index 4c3b8b3..663464f 100644 --- a/rig-core/src/providers/azure.rs +++ b/rig-core/src/providers/azure.rs @@ -12,7 +12,7 @@ use super::openai::{send_compatible_streaming_request, TranscriptionResponse}; use crate::json_utils::merge; -use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult}; +use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse}; use crate::{ agent::AgentBuilder, completion::{self, CompletionError, CompletionRequest}, @@ -570,11 +570,11 @@ impl completion::CompletionModel for CompletionModel { // Azure OpenAI Streaming API // ----------------------------------------------------- impl StreamingCompletionModel for CompletionModel { - type Response = openai::StreamingCompletionResponse; + type StreamingResponse = openai::StreamingCompletionResponse; async fn stream( &self, request: CompletionRequest, - ) -> Result, CompletionError> { + ) -> Result, CompletionError> { let mut request = self.create_completion_request(request)?; request = merge( diff --git a/rig-core/src/providers/deepseek.rs b/rig-core/src/providers/deepseek.rs index 71a058b..b6302a0 100644 --- a/rig-core/src/providers/deepseek.rs +++ b/rig-core/src/providers/deepseek.rs @@ -12,7 +12,7 @@ use crate::json_utils::merge; use crate::providers::openai; use crate::providers::openai::send_compatible_streaming_request; -use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult}; +use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse}; use crate::{ completion::{self, CompletionError, CompletionModel, CompletionRequest}, extractor::ExtractorBuilder, @@ -464,11 +464,11 @@ impl CompletionModel for DeepSeekCompletionModel { } impl StreamingCompletionModel for DeepSeekCompletionModel { - type Response = 
openai::StreamingCompletionResponse; + type StreamingResponse = openai::StreamingCompletionResponse; async fn stream( &self, completion_request: CompletionRequest, - ) -> Result, CompletionError> { + ) -> Result, CompletionError> { let mut request = self.create_completion_request(completion_request)?; request = merge( diff --git a/rig-core/src/providers/galadriel.rs b/rig-core/src/providers/galadriel.rs index 84d9bf8..77341d5 100644 --- a/rig-core/src/providers/galadriel.rs +++ b/rig-core/src/providers/galadriel.rs @@ -13,7 +13,7 @@ use super::openai; use crate::json_utils::merge; use crate::providers::openai::send_compatible_streaming_request; -use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult}; +use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse}; use crate::{ agent::AgentBuilder, completion::{self, CompletionError, CompletionRequest}, @@ -495,12 +495,12 @@ impl completion::CompletionModel for CompletionModel { } impl StreamingCompletionModel for CompletionModel { - type Response = openai::StreamingCompletionResponse; + type StreamingResponse = openai::StreamingCompletionResponse; async fn stream( &self, request: CompletionRequest, - ) -> Result, CompletionError> { + ) -> Result, CompletionError> { let mut request = self.create_completion_request(request)?; request = merge( diff --git a/rig-core/src/providers/gemini/completion.rs b/rig-core/src/providers/gemini/completion.rs index d9d35c2..ccfdeff 100644 --- a/rig-core/src/providers/gemini/completion.rs +++ b/rig-core/src/providers/gemini/completion.rs @@ -609,7 +609,7 @@ pub mod gemini_api_types { HarmCategoryCivicIntegrity, } - #[derive(Debug, Deserialize)] + #[derive(Debug, Deserialize, Clone)] #[serde(rename_all = "camelCase")] pub struct UsageMetadata { pub prompt_token_count: i32, diff --git a/rig-core/src/providers/gemini/streaming.rs b/rig-core/src/providers/gemini/streaming.rs index b3240bc..8e004af 100644 --- a/rig-core/src/providers/gemini/streaming.rs +++ b/rig-core/src/providers/gemini/streaming.rs @@ -2,26 +2,34 @@ use async_stream::stream; use futures::StreamExt; use serde::Deserialize; +use super::completion::{create_request_body, gemini_api_types::ContentCandidate, CompletionModel}; +use crate::providers::gemini::completion::gemini_api_types::UsageMetadata; use crate::{ completion::{CompletionError, CompletionRequest}, - streaming::{self, StreamingCompletionModel, StreamingResult}, + streaming::{self, StreamingCompletionModel}, }; -use super::completion::{create_request_body, gemini_api_types::ContentCandidate, CompletionModel}; - #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct StreamGenerateContentResponse { /// Candidate responses from the model. pub candidates: Vec, pub model_version: Option, + pub usage_metadata: UsageMetadata, +} + +#[derive(Clone)] +pub struct StreamingCompletionResponse { + pub usage_metadata: UsageMetadata, } impl StreamingCompletionModel for CompletionModel { + type StreamingResponse = StreamingCompletionResponse; async fn stream( &self, completion_request: CompletionRequest, - ) -> Result { + ) -> Result, CompletionError> + { let request = create_request_body(completion_request)?; let response = self @@ -42,7 +50,7 @@ impl StreamingCompletionModel for CompletionModel { ))); } - Ok(Box::pin(stream! { + let stream = Box::pin(stream! 
{ let mut stream = response.bytes_stream(); while let Some(chunk_result) = stream.next().await { @@ -79,8 +87,16 @@ impl StreamingCompletionModel for CompletionModel { => yield Ok(streaming::RawStreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)), _ => panic!("Unsupported response type with streaming.") }; + + if choice.finish_reason.is_some() { + yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse { + usage_metadata: data.usage_metadata, + })) + } } } - })) + }); + + Ok(streaming::StreamingCompletionResponse::new(stream)) } } diff --git a/rig-core/src/providers/groq.rs b/rig-core/src/providers/groq.rs index 08ff2ee..ffdf29a 100644 --- a/rig-core/src/providers/groq.rs +++ b/rig-core/src/providers/groq.rs @@ -11,7 +11,7 @@ use super::openai::{send_compatible_streaming_request, CompletionResponse, TranscriptionResponse}; use crate::json_utils::merge; use crate::providers::openai; -use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult}; +use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse}; use crate::{ agent::AgentBuilder, completion::{self, CompletionError, CompletionRequest}, @@ -364,11 +364,11 @@ impl completion::CompletionModel for CompletionModel { } impl StreamingCompletionModel for CompletionModel { - type Response = openai::StreamingCompletionResponse; + type StreamingResponse = openai::StreamingCompletionResponse; async fn stream( &self, request: CompletionRequest, - ) -> Result, CompletionError> { + ) -> Result, CompletionError> { let mut request = self.create_completion_request(request)?; request = merge( diff --git a/rig-core/src/providers/huggingface/streaming.rs b/rig-core/src/providers/huggingface/streaming.rs index 8eea985..90ddd39 100644 --- a/rig-core/src/providers/huggingface/streaming.rs +++ b/rig-core/src/providers/huggingface/streaming.rs @@ -1,9 +1,9 @@ use super::completion::CompletionModel; use crate::completion::{CompletionError, CompletionRequest}; -use crate::json_utils; use crate::json_utils::merge_inplace; -use crate::providers::openai::send_compatible_streaming_request; -use crate::streaming::{StreamingCompletionModel, StreamingResult}; +use crate::providers::openai::{send_compatible_streaming_request, StreamingCompletionResponse}; +use crate::streaming::StreamingCompletionModel; +use crate::{json_utils, streaming}; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use std::convert::Infallible; @@ -55,14 +55,19 @@ struct CompletionChunk { } impl StreamingCompletionModel for CompletionModel { + type StreamingResponse = StreamingCompletionResponse; async fn stream( &self, completion_request: CompletionRequest, - ) -> Result { + ) -> Result, CompletionError> + { let mut request = self.create_request_body(&completion_request)?; // Enable streaming - merge_inplace(&mut request, json!({"stream": true})); + merge_inplace( + &mut request, + json!({"stream": true, "stream_options": {"include_usage": true}}), + ); if let Some(ref params) = completion_request.additional_params { merge_inplace(&mut request, params.clone()); diff --git a/rig-core/src/providers/hyperbolic.rs b/rig-core/src/providers/hyperbolic.rs index 00098e6..e0a69e2 100644 --- a/rig-core/src/providers/hyperbolic.rs +++ b/rig-core/src/providers/hyperbolic.rs @@ -12,7 +12,7 @@ use super::openai::{send_compatible_streaming_request, AssistantContent}; use crate::json_utils::merge_inplace; -use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, 
StreamingResult}; +use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse}; use crate::{ agent::AgentBuilder, completion::{self, CompletionError, CompletionRequest}, @@ -390,11 +390,11 @@ impl completion::CompletionModel for CompletionModel { } impl StreamingCompletionModel for CompletionModel { - type Response = openai::StreamingCompletionResponse; + type StreamingResponse = openai::StreamingCompletionResponse; async fn stream( &self, completion_request: CompletionRequest, - ) -> Result, CompletionError> { + ) -> Result, CompletionError> { let mut request = self.create_completion_request(completion_request)?; merge_inplace( diff --git a/rig-core/src/providers/mira.rs b/rig-core/src/providers/mira.rs index 1f43eb6..16687bf 100644 --- a/rig-core/src/providers/mira.rs +++ b/rig-core/src/providers/mira.rs @@ -8,8 +8,9 @@ //! //! ``` use crate::json_utils::merge; +use crate::providers::openai; use crate::providers::openai::send_compatible_streaming_request; -use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult}; +use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse}; use crate::{ agent::AgentBuilder, completion::{self, CompletionError, CompletionRequest}, @@ -24,7 +25,6 @@ use serde_json::{json, Value}; use std::string::FromUtf8Error; use thiserror::Error; use tracing; -use crate::providers::openai; #[derive(Debug, Error)] pub enum MiraError { @@ -348,11 +348,11 @@ impl completion::CompletionModel for CompletionModel { } impl StreamingCompletionModel for CompletionModel { - type Response = openai::StreamingCompletionResponse; + type StreamingResponse = openai::StreamingCompletionResponse; async fn stream( &self, completion_request: CompletionRequest, - ) -> Result, CompletionError> { + ) -> Result, CompletionError> { let mut request = self.create_completion_request(completion_request)?; request = merge(request, json!({"stream": true})); diff --git a/rig-core/src/providers/moonshot.rs b/rig-core/src/providers/moonshot.rs index 2906a89..9023133 100644 --- a/rig-core/src/providers/moonshot.rs +++ b/rig-core/src/providers/moonshot.rs @@ -11,7 +11,7 @@ use crate::json_utils::merge; use crate::providers::openai::send_compatible_streaming_request; -use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult}; +use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse}; use crate::{ agent::AgentBuilder, completion::{self, CompletionError, CompletionRequest}, @@ -228,12 +228,12 @@ impl completion::CompletionModel for CompletionModel { } impl StreamingCompletionModel for CompletionModel { - type Response = openai::StreamingCompletionResponse; + type StreamingResponse = openai::StreamingCompletionResponse; async fn stream( &self, request: CompletionRequest, - ) -> Result, CompletionError> { + ) -> Result, CompletionError> { let mut request = self.create_completion_request(request)?; request = merge( diff --git a/rig-core/src/providers/ollama.rs b/rig-core/src/providers/ollama.rs index 1429ad0..4c93475 100644 --- a/rig-core/src/providers/ollama.rs +++ b/rig-core/src/providers/ollama.rs @@ -39,7 +39,7 @@ //! let extractor = client.extractor::("llama3.2"); //! 
```
use crate::json_utils::merge_inplace;
-use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult};
+use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
use crate::{
    agent::AgentBuilder,
    completion::{self, CompletionError, CompletionRequest},
@@ -405,30 +405,25 @@ impl completion::CompletionModel for CompletionModel {
    }
}

+#[derive(Clone)]
pub struct StreamingCompletionResponse {
-    #[serde(default)]
    pub done_reason: Option<String>,
-    #[serde(default)]
    pub total_duration: Option<u64>,
-    #[serde(default)]
    pub load_duration: Option<u64>,
-    #[serde(default)]
    pub prompt_eval_count: Option<u64>,
-    #[serde(default)]
    pub prompt_eval_duration: Option<u64>,
-    #[serde(default)]
    pub eval_count: Option<u64>,
-    #[serde(default)]
    pub eval_duration: Option<u64>,
}

impl StreamingCompletionModel for CompletionModel {
-    type Response = StreamingCompletionResponse;
+    type StreamingResponse = StreamingCompletionResponse;

    async fn stream(
        &self,
        request: CompletionRequest,
-    ) -> Result<streaming::StreamingCompletionResponse<Self::Response>, CompletionError> {
+    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
+    {
        let mut request_payload = self.create_completion_request(request)?;
        merge_inplace(&mut request_payload, json!({"stream": true}));

@@ -448,7 +443,6 @@ impl StreamingCompletionModel for CompletionModel {
            return Err(CompletionError::ProviderError(err_text));
        }

-        let mut
        let stream = Box::pin(stream! {
            let mut stream = response.bytes_stream();
            while let Some(chunk_result) = stream.next().await {
@@ -495,6 +489,8 @@ impl StreamingCompletionModel for CompletionModel {
                }
            }
        });
+
+        Ok(streaming::StreamingCompletionResponse::new(stream))
    }
}
diff --git a/rig-core/src/providers/openai/streaming.rs
index af68d58..015bc8a 100644
--- a/rig-core/src/providers/openai/streaming.rs
+++ b/rig-core/src/providers/openai/streaming.rs
@@ -4,14 +4,13 @@ use crate::json_utils;
use crate::json_utils::merge;
use crate::providers::openai::Usage;
use crate::streaming;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
use async_stream::stream;
use futures::StreamExt;
use reqwest::RequestBuilder;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
-use tokio::stream;

// ================================================================
// OpenAI Completion Streaming API
@@ -47,18 +46,21 @@ struct StreamingChoice {
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    usage: Option<Usage>,
+    finish_reason: Option<String>,
}

+#[derive(Clone)]
pub struct StreamingCompletionResponse {
-    usage: Option<Usage>,
+    pub usage: Option<Usage>,
}

impl StreamingCompletionModel for CompletionModel {
-    type Response = StreamingCompletionResponse;
+    type StreamingResponse = StreamingCompletionResponse;
    async fn stream(
        &self,
        completion_request: CompletionRequest,
-    ) -> Result<streaming::StreamingCompletionResponse<Self::Response>, CompletionError> {
+    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
+    {
        let mut request = self.create_completion_request(completion_request)?;

        request = merge(request, json!({"stream": true}));
@@ -179,9 +181,12 @@ pub async fn send_compatible_streaming_request(
                    yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
                }

-                if &data.usage.is_some() {
-                    usage = data.usage;
+                if data.finish_reason.is_some() {
+                    yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
+                        usage: data.usage
+                    }))
                }
+
            }
        }

@@ -194,7 +199,5 @@ pub async fn send_compatible_streaming_request(
            continue;
        };

-        yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments))
+        yield Ok(RawStreamingChoice::ToolCall(name, "".to_string(), arguments))
    }

-    Ok(streaming::StreamingCompletionResponse::new(
-        inner,
-    ))
+    Ok(streaming::StreamingCompletionResponse::new(inner))
}
diff
--git a/rig-core/src/providers/perplexity.rs
index 1b3d1ea..c09c4d6 100644
--- a/rig-core/src/providers/perplexity.rs
+++ b/rig-core/src/providers/perplexity.rs
@@ -18,8 +18,9 @@ use crate::{
use crate::completion::CompletionRequest;
use crate::json_utils::merge;
+use crate::providers::openai;
use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
@@ -345,10 +346,11 @@ impl completion::CompletionModel for CompletionModel {
}

impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = openai::StreamingCompletionResponse;
    async fn stream(
        &self,
        completion_request: completion::CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let mut request = self.create_completion_request(completion_request)?;
        request = merge(request, json!({"stream": true}));

diff --git a/rig-core/src/providers/together/streaming.rs
index dffb3b8..f437828 100644
--- a/rig-core/src/providers/together/streaming.rs
+++ b/rig-core/src/providers/together/streaming.rs
@@ -1,18 +1,21 @@
use serde_json::json;

use super::completion::CompletionModel;
+use crate::providers::openai;
use crate::providers::openai::send_compatible_streaming_request;
+use crate::streaming::StreamingCompletionResponse;
use crate::{
    completion::{CompletionError, CompletionRequest},
    json_utils::merge,
-    streaming::{StreamingCompletionModel, StreamingResult},
+    streaming::StreamingCompletionModel,
};

impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = openai::StreamingCompletionResponse;
    async fn stream(
        &self,
        completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let mut request = self.create_completion_request(completion_request)?;

        request = merge(request, json!({"stream_tokens": true}));

diff --git a/rig-core/src/providers/xai/streaming.rs
index 80ab430..b87a53e 100644
--- a/rig-core/src/providers/xai/streaming.rs
+++ b/rig-core/src/providers/xai/streaming.rs
@@ -1,15 +1,17 @@
use crate::completion::{CompletionError, CompletionRequest};
use crate::json_utils::merge;
+use crate::providers::openai;
use crate::providers::openai::send_compatible_streaming_request;
use crate::providers::xai::completion::CompletionModel;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use serde_json::json;

impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = openai::StreamingCompletionResponse;
    async fn stream(
        &self,
        completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let mut request = self.create_completion_request(completion_request)?;
        request = merge(request, json!({"stream": true}));

diff --git a/rig-core/src/streaming.rs
index 62e7ec2..4d61116 100644
--- a/rig-core/src/streaming.rs
+++ b/rig-core/src/streaming.rs
@@ -14,6 +14,7 @@ use crate::completion::{
    CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, Message,
};
use crate::message::AssistantContent;
+use crate::OneOrMany;
use futures::{Stream, StreamExt};
use std::boxed::Box;
use std::fmt::{Display, Formatter};
@@ -23,7 +24,7 @@
use std::task::{Context, Poll};

/// Enum representing a streaming chunk from the model
#[derive(Debug, Clone)]
-pub enum RawStreamingChoice<R> {
+pub enum RawStreamingChoice<R: Clone> {
    /// A text chunk from a message response
    Message(String),

@@ -35,6 +36,7 @@ pub enum RawStreamingChoice<R> {
}

/// Enum representing a streaming chunk from the model
+#[derive(Debug, Clone)]
pub enum StreamingChoice {
    /// A text chunk from a message response
    Message(String),

@@ -61,7 +63,7 @@ pub type StreamingResult<R> =
#[cfg(target_arch = "wasm32")]
pub type StreamingResult<R> = Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;

-pub struct StreamingCompletionResponse<R> {
+pub struct StreamingCompletionResponse<R: Clone + Unpin> {
    inner: StreamingResult<R>,
    text: String,
    tool_calls: Vec<(String, String, serde_json::Value)>,
@@ -69,7 +71,7 @@ pub struct StreamingCompletionResponse<R> {
    pub response: Option<R>,
}

-impl<R> StreamingCompletionResponse<R> {
+impl<R: Clone + Unpin> StreamingCompletionResponse<R> {
    pub fn new(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
        Self {
            inner,
@@ -81,37 +83,43 @@ impl<R> StreamingCompletionResponse<R> {
    }
}

-impl<R> Stream for StreamingCompletionResponse<R> {
+impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
    type Item = Result<StreamingChoice, CompletionError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        match self.inner.poll_next(cx) {
+        let stream = self.get_mut();
+
+        match stream.inner.as_mut().poll_next(cx) {
            Poll::Pending => Poll::Pending,

            Poll::Ready(None) => {
-                let content = vec![AssistantContent::text(self.text.clone())];
+                let content = vec![AssistantContent::text(stream.text.clone())];

-                self.tool_calls
-                    .iter()
-                    .for_each(|(n, d, a)| AssistantContent::tool_call(n, derive!(), a));
+                stream.tool_calls.iter().for_each(|(n, d, a)| {
+                    AssistantContent::tool_call(n, d, a.clone());
+                });

-                self.message = Message::Assistant {
-                    content: content.into(),
-                }
+                stream.message = Message::Assistant {
+                    content: OneOrMany::many(content)
+                        .expect("There should be at least one assistant message"),
+                };
+
+                Poll::Ready(None)
            }
            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
            Poll::Ready(Some(Ok(choice))) => match choice {
                RawStreamingChoice::Message(text) => {
-                    self.text = format!("{}{}", self.text, text);
-                    Poll::Ready(Some(Ok(choice.clone())))
+                    stream.text = format!("{}{}", stream.text, text.clone());
+                    Poll::Ready(Some(Ok(StreamingChoice::Message(text))))
                }
                RawStreamingChoice::ToolCall(name, description, args) => {
-                    self.tool_calls
-                        .push((name, description, args));
-                    Poll::Ready(Some(Ok(choice.clone())))
+                    stream
+                        .tool_calls
+                        .push((name.clone(), description.clone(), args.clone()));
+                    Poll::Ready(Some(Ok(StreamingChoice::ToolCall(name, description, args))))
                }
                RawStreamingChoice::FinalResponse(response) => {
-                    self.response = Some(response);
+                    stream.response = Some(response);
                    Poll::Pending
                }
            },
@@ -120,7 +128,7 @@ impl<R> Stream for StreamingCompletionResponse<R> {
}

/// Trait for high-level streaming prompt interface
-pub trait StreamingPrompt<R>: Send + Sync {
+pub trait StreamingPrompt<R: Clone + Unpin>: Send + Sync {
    /// Stream a simple prompt to the model
    fn stream_prompt(
        &self,
@@ -129,7 +137,7 @@ pub trait StreamingPrompt<R>: Send + Sync {
}

/// Trait for high-level streaming chat interface
-pub trait StreamingChat<R>: Send + Sync {
+pub trait StreamingChat<R: Clone + Unpin>: Send + Sync {
    /// Stream a chat with history to the model
    fn stream_chat(
        &self,
@@ -150,18 +158,20 @@ pub trait StreamingCompletion<M: StreamingCompletionModel> {

/// Trait defining a streaming completion model
pub trait StreamingCompletionModel: CompletionModel {
-    type Response;
+    type StreamingResponse: Clone + Unpin;
    /// Stream a completion response for the given request
    fn stream(
        &self,
        request: CompletionRequest,
-    ) -> impl Future<Output = Result<StreamingCompletionResponse<Self::Response>, CompletionError>>;
+    ) -> impl Future<
+        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
+    >;
}

/// helper function to stream a completion request to stdout
pub async fn stream_to_stdout<M: StreamingCompletionModel>(
    agent: Agent<M>,
-    stream: &mut StreamingCompletionResponse<M::Response>,
+    stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
) -> Result<(), std::io::Error> {
    print!("Response: ");
    while let Some(chunk) = stream.next().await {

From 2920eb0a0e91f06034e5678aa8331f546d34d998 Mon Sep 17 00:00:00 2001
From: yavens <179155341+yavens@users.github.noreply.github.com>
Date: Wed, 9 Apr 2025 15:33:26 -0400
Subject: [PATCH 03/16] feat: update examples

---
 rig-core/examples/anthropic_streaming.rs      |  8 ++
 .../anthropic_streaming_with_tools.rs         |  7 ++
 rig-core/examples/gemini_streaming.rs         |  5 +
 .../examples/gemini_streaming_with_tools.rs   |  7 ++
 rig-core/examples/ollama_streaming.rs         |  5 +
 .../examples/ollama_streaming_with_tools.rs   |  7 ++
 rig-core/examples/openai_streaming.rs         |  2 +
 .../examples/openai_streaming_with_tools.rs   |  7 ++
 rig-core/src/providers/gemini/completion.rs   |  2 +-
 rig-core/src/providers/gemini/streaming.rs    | 14 ++-
 rig-core/src/providers/ollama.rs              | 12 +++
 rig-core/src/providers/openai/streaming.rs    | 96 ++++++++++---------
 rig-core/src/streaming.rs                     | 17 ++--
 13 files changed, 134 insertions(+), 55 deletions(-)

diff --git a/rig-core/examples/anthropic_streaming.rs
index 349a45d..015189a 100644
--- a/rig-core/examples/anthropic_streaming.rs
+++ b/rig-core/examples/anthropic_streaming.rs
@@ -19,5 +19,13 @@ async fn main() -> Result<(), anyhow::Error> {

    stream_to_stdout(agent, &mut stream).await?;

+
+    if let Some(response) = stream.response {
+        println!("Usage: {:?} tokens", response.usage.output_tokens);
+    };
+
+    println!("Message: {:?}", stream.message);
+
+
    Ok(())
}
diff --git a/rig-core/examples/anthropic_streaming_with_tools.rs
index ec3ee7c..9dc53b4 100644
--- a/rig-core/examples/anthropic_streaming_with_tools.rs
+++ b/rig-core/examples/anthropic_streaming_with_tools.rs
@@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> {
    println!("Calculate 2 - 5");
    let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
    stream_to_stdout(calculator_agent, &mut stream).await?;
+
+    if let Some(response) = stream.response {
+        println!("Usage: {:?} tokens", response.usage.output_tokens);
+    };
+
+    println!("Message: {:?}", stream.message);
+
    Ok(())
}
diff --git a/rig-core/examples/gemini_streaming.rs
index 1ff711b..34f57bd 100644
--- a/rig-core/examples/gemini_streaming.rs
+++ b/rig-core/examples/gemini_streaming.rs
@@ -19,5 +19,10 @@ async fn main() -> Result<(), anyhow::Error> {

    stream_to_stdout(agent, &mut stream).await?;

+    if let Some(response) = stream.response {
+        println!("Usage: {:?} tokens", response.usage_metadata.total_token_count);
+    };
+
+    println!("Message: {:?}", stream.message);
    Ok(())
}
diff --git a/rig-core/examples/gemini_streaming_with_tools.rs
index 43f469d..a5fb1c9 100644
--- a/rig-core/examples/gemini_streaming_with_tools.rs
+++ b/rig-core/examples/gemini_streaming_with_tools.rs
@@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> {
    println!("Calculate 2 - 5");
    let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
    stream_to_stdout(calculator_agent, &mut stream).await?;
+
+    if let
Some(response) = stream.response { + println!("Usage: {:?} tokens", response.usage_metadata.total_token_count); + }; + + println!("Message: {:?}", stream.message); + Ok(()) } diff --git a/rig-core/examples/ollama_streaming.rs b/rig-core/examples/ollama_streaming.rs index fe12467..1bc1e33 100644 --- a/rig-core/examples/ollama_streaming.rs +++ b/rig-core/examples/ollama_streaming.rs @@ -17,5 +17,10 @@ async fn main() -> Result<(), anyhow::Error> { stream_to_stdout(agent, &mut stream).await?; + if let Some(response) = stream.response { + println!("Usage: {:?} tokens", response.eval_count); + }; + + println!("Message: {:?}", stream.message); Ok(()) } diff --git a/rig-core/examples/ollama_streaming_with_tools.rs b/rig-core/examples/ollama_streaming_with_tools.rs index 0e59549..4b4427f 100644 --- a/rig-core/examples/ollama_streaming_with_tools.rs +++ b/rig-core/examples/ollama_streaming_with_tools.rs @@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> { println!("Calculate 2 - 5"); let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?; stream_to_stdout(calculator_agent, &mut stream).await?; + + if let Some(response) = stream.response { + println!("Usage: {:?} tokens", response.eval_count); + }; + + println!("Message: {:?}", stream.message); + Ok(()) } diff --git a/rig-core/examples/openai_streaming.rs b/rig-core/examples/openai_streaming.rs index a474702..f668036 100644 --- a/rig-core/examples/openai_streaming.rs +++ b/rig-core/examples/openai_streaming.rs @@ -20,6 +20,8 @@ async fn main() -> Result<(), anyhow::Error> { if let Some(response) = stream.response { println!("Usage: {:?}", response.usage) }; + + println!("Message: {:?}", stream.message); Ok(()) } diff --git a/rig-core/examples/openai_streaming_with_tools.rs b/rig-core/examples/openai_streaming_with_tools.rs index 997bebb..72c4aeb 100644 --- a/rig-core/examples/openai_streaming_with_tools.rs +++ b/rig-core/examples/openai_streaming_with_tools.rs @@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> { println!("Calculate 2 - 5"); let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?; stream_to_stdout(calculator_agent, &mut stream).await?; + + if let Some(response) = stream.response { + println!("Usage: {:?}", response.usage) + }; + + println!("Message: {:?}", stream.message); + Ok(()) } diff --git a/rig-core/src/providers/gemini/completion.rs b/rig-core/src/providers/gemini/completion.rs index ccfdeff..7491fc5 100644 --- a/rig-core/src/providers/gemini/completion.rs +++ b/rig-core/src/providers/gemini/completion.rs @@ -609,7 +609,7 @@ pub mod gemini_api_types { HarmCategoryCivicIntegrity, } - #[derive(Debug, Deserialize, Clone)] + #[derive(Debug, Deserialize, Clone, Default)] #[serde(rename_all = "camelCase")] pub struct UsageMetadata { pub prompt_token_count: i32, diff --git a/rig-core/src/providers/gemini/streaming.rs b/rig-core/src/providers/gemini/streaming.rs index 8e004af..66a6a82 100644 --- a/rig-core/src/providers/gemini/streaming.rs +++ b/rig-core/src/providers/gemini/streaming.rs @@ -9,18 +9,24 @@ use crate::{ streaming::{self, StreamingCompletionModel}, }; +#[derive(Debug, Deserialize, Default, Clone)] +#[serde(rename_all = "camelCase")] +pub struct PartialUsage { + pub total_token_count: i32, +} + #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct StreamGenerateContentResponse { /// Candidate responses from the model. 
pub candidates: Vec<ContentCandidate>,
     pub model_version: Option<String>,
-    pub usage_metadata: UsageMetadata,
+    pub usage_metadata: Option<PartialUsage>,
 }
 
 #[derive(Clone)]
 pub struct StreamingCompletionResponse {
-    pub usage_metadata: UsageMetadata,
+    pub usage_metadata: PartialUsage,
 }
 
 impl StreamingCompletionModel for CompletionModel {
@@ -90,7 +96,9 @@ impl StreamingCompletionModel for CompletionModel {
 
                     if choice.finish_reason.is_some() {
                         yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
-                            usage_metadata: data.usage_metadata,
+                            usage_metadata: PartialUsage {
+                                total_token_count: data.usage_metadata.unwrap_or_default().total_token_count,
+                            }
                         }))
                     }
                 }
diff --git a/rig-core/src/providers/ollama.rs b/rig-core/src/providers/ollama.rs
index 4c93475..a754fe2 100644
--- a/rig-core/src/providers/ollama.rs
+++ b/rig-core/src/providers/ollama.rs
@@ -486,6 +486,18 @@ impl StreamingCompletionModel for CompletionModel {
                             continue;
                         }
                     }
+
+                    if response.done {
+                        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
+                            total_duration: response.total_duration,
+                            load_duration: response.load_duration,
+                            prompt_eval_count: response.prompt_eval_count,
+                            prompt_eval_duration: response.prompt_eval_duration,
+                            eval_count: response.eval_count,
+                            eval_duration: response.eval_duration,
+                            done_reason: response.done_reason,
+                        }));
+                    }
                 }
             }
         });
diff --git a/rig-core/src/providers/openai/streaming.rs b/rig-core/src/providers/openai/streaming.rs
index 015bc8a..c42f1a4 100644
--- a/rig-core/src/providers/openai/streaming.rs
+++ b/rig-core/src/providers/openai/streaming.rs
@@ -46,12 +46,11 @@ struct StreamingChoice {
 struct StreamingCompletionChunk {
     choices: Vec<StreamingChoice>,
     usage: Option<Usage>,
-    finish_reason: Option<String>,
 }
 
 #[derive(Clone)]
 pub struct StreamingCompletionResponse {
-    pub usage: Option<Usage>,
+    pub usage: Usage,
 }
 
 impl StreamingCompletionModel for CompletionModel {
@@ -62,7 +61,10 @@ impl StreamingCompletionModel for CompletionModel {
     ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;
 
-        request = merge(request, json!({"stream": true}));
+        request = merge(
+            request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );
 
         let builder = self.client.post("/chat/completions").json(&request);
         send_compatible_streaming_request(builder).await
@@ -86,6 +88,11 @@ pub async fn send_compatible_streaming_request(
     let inner = Box::pin(stream!
{ let mut stream = response.bytes_stream(); + let mut final_usage = Usage { + prompt_tokens: 0, + total_tokens: 0 + }; + let mut partial_data = None; let mut calls: HashMap = HashMap::new(); @@ -110,8 +117,6 @@ pub async fn send_compatible_streaming_request( for line in text.lines() { let mut line = line.to_string(); - - // If there was a remaining part, concat with current line if partial_data.is_some() { line = format!("{}{}", partial_data.unwrap(), line); @@ -137,56 +142,53 @@ pub async fn send_compatible_streaming_request( continue; }; - let choice = data.choices.first().expect("Should have at least one choice"); - let delta = &choice.delta; + if let Some(choice) = data.choices.first() { - if !delta.tool_calls.is_empty() { - for tool_call in &delta.tool_calls { - let function = tool_call.function.clone(); + let delta = &choice.delta; - // Start of tool call - // name: Some(String) - // arguments: None - if function.name.is_some() && function.arguments.is_empty() { - calls.insert(tool_call.index, (function.name.clone().unwrap(), "".to_string())); - } - // Part of tool call - // name: None - // arguments: Some(String) - else if function.name.is_none() && !function.arguments.is_empty() { - let Some((name, arguments)) = calls.get(&tool_call.index) else { - continue; - }; + if !delta.tool_calls.is_empty() { + for tool_call in &delta.tool_calls { + let function = tool_call.function.clone(); - let new_arguments = &tool_call.function.arguments; - let arguments = format!("{}{}", arguments, new_arguments); + // Start of tool call + // name: Some(String) + // arguments: None + if function.name.is_some() && function.arguments.is_empty() { + calls.insert(tool_call.index, (function.name.clone().unwrap(), "".to_string())); + } + // Part of tool call + // name: None + // arguments: Some(String) + else if function.name.is_none() && !function.arguments.is_empty() { + let Some((name, arguments)) = calls.get(&tool_call.index) else { + continue; + }; - calls.insert(tool_call.index, (name.clone(), arguments)); - } - // Entire tool call - else { - let name = function.name.unwrap(); - let arguments = function.arguments; - let Ok(arguments) = serde_json::from_str(&arguments) else { - continue; - }; + let new_arguments = &tool_call.function.arguments; + let arguments = format!("{}{}", arguments, new_arguments); - yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments)) + calls.insert(tool_call.index, (name.clone(), arguments)); + } + // Entire tool call + else { + let name = function.name.unwrap(); + let arguments = function.arguments; + let Ok(arguments) = serde_json::from_str(&arguments) else { + continue; + }; + + yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments)) + } } } + + } - if let Some(content) = &choice.delta.content { - yield Ok(streaming::RawStreamingChoice::Message(content.clone())) + if let Some(usage) = data.usage { + final_usage = usage.clone(); } - - if data.finish_reason.is_some() { - yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse { - usage: data.usage - })) - } - } } @@ -195,8 +197,12 @@ pub async fn send_compatible_streaming_request( continue; }; - yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments)) + yield Ok(RawStreamingChoice::ToolCall(name, "".to_string(), arguments)) } + + yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse { + usage: final_usage.clone() + })) }); Ok(streaming::StreamingCompletionResponse::new(inner)) diff --git a/rig-core/src/streaming.rs 
b/rig-core/src/streaming.rs index 4d61116..5edc2c3 100644 --- a/rig-core/src/streaming.rs +++ b/rig-core/src/streaming.rs @@ -91,21 +91,25 @@ impl Stream for StreamingCompletionResponse { match stream.inner.as_mut().poll_next(cx) { Poll::Pending => Poll::Pending, - Poll::Ready(None) => { - let content = vec![AssistantContent::text(stream.text.clone())]; + + let mut content = vec![]; stream.tool_calls.iter().for_each(|(n, d, a)| { - AssistantContent::tool_call(n, d, a.clone()); + content.push(AssistantContent::tool_call(n, d, a.clone())); }); + if content.len() == 0 || stream.text.len() > 0 { + content.insert(0, AssistantContent::text(stream.text.clone())); + } + stream.message = Message::Assistant { content: OneOrMany::many(content) .expect("There should be at least one assistant message"), }; - + Poll::Ready(None) - } + }, Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), Poll::Ready(Some(Ok(choice))) => match choice { RawStreamingChoice::Message(text) => { @@ -120,7 +124,8 @@ impl Stream for StreamingCompletionResponse { } RawStreamingChoice::FinalResponse(response) => { stream.response = Some(response); - Poll::Pending + + stream.poll_next_unpin(cx) } }, } From cad584455aa98ba2e54555bc13cfbf76471d6ccb Mon Sep 17 00:00:00 2001 From: yavens <179155341+yavens@users.github.noreply.github.com> Date: Wed, 9 Apr 2025 15:40:57 -0400 Subject: [PATCH 04/16] fix: wasm missing generic + fmt --- rig-core/examples/anthropic_streaming.rs | 4 +--- rig-core/examples/gemini_streaming.rs | 5 ++++- rig-core/examples/gemini_streaming_with_tools.rs | 7 +++++-- rig-core/examples/ollama_streaming_with_tools.rs | 2 +- rig-core/examples/openai_streaming.rs | 4 ++-- rig-core/examples/openai_streaming_with_tools.rs | 4 ++-- rig-core/src/providers/gemini/streaming.rs | 1 - rig-core/src/providers/ollama.rs | 2 +- rig-core/src/providers/openai/streaming.rs | 7 +++++-- rig-core/src/streaming.rs | 10 +++++----- 10 files changed, 26 insertions(+), 20 deletions(-) diff --git a/rig-core/examples/anthropic_streaming.rs b/rig-core/examples/anthropic_streaming.rs index 015189a..b61ef96 100644 --- a/rig-core/examples/anthropic_streaming.rs +++ b/rig-core/examples/anthropic_streaming.rs @@ -19,13 +19,11 @@ async fn main() -> Result<(), anyhow::Error> { stream_to_stdout(agent, &mut stream).await?; - if let Some(response) = stream.response { println!("Usage: {:?} tokens", response.usage.output_tokens); }; println!("Message: {:?}", stream.message); - - + Ok(()) } diff --git a/rig-core/examples/gemini_streaming.rs b/rig-core/examples/gemini_streaming.rs index 34f57bd..fc22907 100644 --- a/rig-core/examples/gemini_streaming.rs +++ b/rig-core/examples/gemini_streaming.rs @@ -20,7 +20,10 @@ async fn main() -> Result<(), anyhow::Error> { stream_to_stdout(agent, &mut stream).await?; if let Some(response) = stream.response { - println!("Usage: {:?} tokens", response.usage_metadata.total_token_count); + println!( + "Usage: {:?} tokens", + response.usage_metadata.total_token_count + ); }; println!("Message: {:?}", stream.message); diff --git a/rig-core/examples/gemini_streaming_with_tools.rs b/rig-core/examples/gemini_streaming_with_tools.rs index a5fb1c9..8c9bffe 100644 --- a/rig-core/examples/gemini_streaming_with_tools.rs +++ b/rig-core/examples/gemini_streaming_with_tools.rs @@ -109,10 +109,13 @@ async fn main() -> Result<(), anyhow::Error> { stream_to_stdout(calculator_agent, &mut stream).await?; if let Some(response) = stream.response { - println!("Usage: {:?} tokens", response.usage_metadata.total_token_count); + 
println!( + "Usage: {:?} tokens", + response.usage_metadata.total_token_count + ); }; println!("Message: {:?}", stream.message); - + Ok(()) } diff --git a/rig-core/examples/ollama_streaming_with_tools.rs b/rig-core/examples/ollama_streaming_with_tools.rs index 4b4427f..37b5854 100644 --- a/rig-core/examples/ollama_streaming_with_tools.rs +++ b/rig-core/examples/ollama_streaming_with_tools.rs @@ -113,6 +113,6 @@ async fn main() -> Result<(), anyhow::Error> { }; println!("Message: {:?}", stream.message); - + Ok(()) } diff --git a/rig-core/examples/openai_streaming.rs b/rig-core/examples/openai_streaming.rs index f668036..40ec569 100644 --- a/rig-core/examples/openai_streaming.rs +++ b/rig-core/examples/openai_streaming.rs @@ -16,11 +16,11 @@ async fn main() -> Result<(), anyhow::Error> { .await?; stream_to_stdout(agent, &mut stream).await?; - + if let Some(response) = stream.response { println!("Usage: {:?}", response.usage) }; - + println!("Message: {:?}", stream.message); Ok(()) diff --git a/rig-core/examples/openai_streaming_with_tools.rs b/rig-core/examples/openai_streaming_with_tools.rs index 72c4aeb..b83a1fa 100644 --- a/rig-core/examples/openai_streaming_with_tools.rs +++ b/rig-core/examples/openai_streaming_with_tools.rs @@ -107,12 +107,12 @@ async fn main() -> Result<(), anyhow::Error> { println!("Calculate 2 - 5"); let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?; stream_to_stdout(calculator_agent, &mut stream).await?; - + if let Some(response) = stream.response { println!("Usage: {:?}", response.usage) }; println!("Message: {:?}", stream.message); - + Ok(()) } diff --git a/rig-core/src/providers/gemini/streaming.rs b/rig-core/src/providers/gemini/streaming.rs index 66a6a82..e48c0b8 100644 --- a/rig-core/src/providers/gemini/streaming.rs +++ b/rig-core/src/providers/gemini/streaming.rs @@ -3,7 +3,6 @@ use futures::StreamExt; use serde::Deserialize; use super::completion::{create_request_body, gemini_api_types::ContentCandidate, CompletionModel}; -use crate::providers::gemini::completion::gemini_api_types::UsageMetadata; use crate::{ completion::{CompletionError, CompletionRequest}, streaming::{self, StreamingCompletionModel}, diff --git a/rig-core/src/providers/ollama.rs b/rig-core/src/providers/ollama.rs index a754fe2..52df410 100644 --- a/rig-core/src/providers/ollama.rs +++ b/rig-core/src/providers/ollama.rs @@ -486,7 +486,7 @@ impl StreamingCompletionModel for CompletionModel { continue; } } - + if response.done { yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse { total_duration: response.total_duration, diff --git a/rig-core/src/providers/openai/streaming.rs b/rig-core/src/providers/openai/streaming.rs index c42f1a4..c3e83ac 100644 --- a/rig-core/src/providers/openai/streaming.rs +++ b/rig-core/src/providers/openai/streaming.rs @@ -92,7 +92,7 @@ pub async fn send_compatible_streaming_request( prompt_tokens: 0, total_tokens: 0 }; - + let mut partial_data = None; let mut calls: HashMap = HashMap::new(); @@ -183,9 +183,12 @@ pub async fn send_compatible_streaming_request( } } - + if let Some(content) = &choice.delta.content { + yield Ok(streaming::RawStreamingChoice::Message(content.clone())) + } } + if let Some(usage) = data.usage { final_usage = usage.clone(); } diff --git a/rig-core/src/streaming.rs b/rig-core/src/streaming.rs index 5edc2c3..9eff6ab 100644 --- a/rig-core/src/streaming.rs +++ b/rig-core/src/streaming.rs @@ -61,7 +61,8 @@ pub type StreamingResult = Pin, CompletionError>> + Send>>; #[cfg(target_arch = "wasm32")] -pub 
type StreamingResult = Pin>>>; +pub type StreamingResult = + Pin, CompletionError>>>>; pub struct StreamingCompletionResponse { inner: StreamingResult, @@ -92,14 +93,13 @@ impl Stream for StreamingCompletionResponse { match stream.inner.as_mut().poll_next(cx) { Poll::Pending => Poll::Pending, Poll::Ready(None) => { - let mut content = vec![]; stream.tool_calls.iter().for_each(|(n, d, a)| { content.push(AssistantContent::tool_call(n, d, a.clone())); }); - if content.len() == 0 || stream.text.len() > 0 { + if content.is_empty() || !stream.text.is_empty() { content.insert(0, AssistantContent::text(stream.text.clone())); } @@ -107,9 +107,9 @@ impl Stream for StreamingCompletionResponse { content: OneOrMany::many(content) .expect("There should be at least one assistant message"), }; - + Poll::Ready(None) - }, + } Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), Poll::Ready(Some(Ok(choice))) => match choice { RawStreamingChoice::Message(text) => { From 0abd7b8c7629363f40c32a00814723ea48965655 Mon Sep 17 00:00:00 2001 From: yavens <179155341+yavens@users.github.noreply.github.com> Date: Thu, 10 Apr 2025 19:55:42 -0400 Subject: [PATCH 05/16] feat: cohere streaming + unify StreamingChoice w/ message --- rig-core/examples/cohere_streaming.rs | 27 +++ .../examples/cohere_streaming_with_tools.rs | 118 +++++++++++ rig-core/src/providers/cohere/completion.rs | 46 +++-- rig-core/src/providers/cohere/mod.rs | 3 +- rig-core/src/providers/cohere/streaming.rs | 195 ++++++++++++++++++ rig-core/src/streaming.rs | 42 ++-- 6 files changed, 378 insertions(+), 53 deletions(-) create mode 100644 rig-core/examples/cohere_streaming.rs create mode 100644 rig-core/examples/cohere_streaming_with_tools.rs create mode 100644 rig-core/src/providers/cohere/streaming.rs diff --git a/rig-core/examples/cohere_streaming.rs b/rig-core/examples/cohere_streaming.rs new file mode 100644 index 0000000..bf1f16c --- /dev/null +++ b/rig-core/examples/cohere_streaming.rs @@ -0,0 +1,27 @@ +use rig::providers::cohere; +use rig::streaming::{stream_to_stdout, StreamingPrompt}; + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + // Create streaming agent with a single context prompt + let agent = cohere::Client::from_env() + .agent(cohere::COMMAND) + .preamble("Be precise and concise.") + .temperature(0.5) + .build(); + + // Stream the response and print chunks as they arrive + let mut stream = agent + .stream_prompt("When and where and what type is the next solar eclipse?") + .await?; + + stream_to_stdout(agent, &mut stream).await?; + + if let Some(response) = stream.response { + println!("Usage: {:?} tokens", response.usage); + }; + + println!("Message: {:?}", stream.message); + + Ok(()) +} diff --git a/rig-core/examples/cohere_streaming_with_tools.rs b/rig-core/examples/cohere_streaming_with_tools.rs new file mode 100644 index 0000000..c4d5961 --- /dev/null +++ b/rig-core/examples/cohere_streaming_with_tools.rs @@ -0,0 +1,118 @@ +use anyhow::Result; +use rig::streaming::stream_to_stdout; +use rig::{completion::ToolDefinition, providers, streaming::StreamingPrompt, tool::Tool}; +use serde::{Deserialize, Serialize}; +use serde_json::json; + +#[derive(Deserialize)] +struct OperationArgs { + x: i32, + y: i32, +} + +#[derive(Debug, thiserror::Error)] +#[error("Math error")] +struct MathError; + +#[derive(Deserialize, Serialize)] +struct Adder; +impl Tool for Adder { + const NAME: &'static str = "add"; + + type Error = MathError; + type Args = OperationArgs; + type Output = i32; + + async fn definition(&self, _prompt: 
String) -> ToolDefinition { + ToolDefinition { + name: "add".to_string(), + description: "Add x and y together".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "The first number to add" + }, + "y": { + "type": "number", + "description": "The second number to add" + } + }, + "required": ["x", "y"] + }), + } + } + + async fn call(&self, args: Self::Args) -> Result { + let result = args.x + args.y; + Ok(result) + } +} + +#[derive(Deserialize, Serialize)] +struct Subtract; +impl Tool for Subtract { + const NAME: &'static str = "subtract"; + + type Error = MathError; + type Args = OperationArgs; + type Output = i32; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + serde_json::from_value(json!({ + "name": "subtract", + "description": "Subtract y from x (i.e.: x - y)", + "parameters": { + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "The number to subtract from" + }, + "y": { + "type": "number", + "description": "The number to subtract" + } + }, + "required": ["x", "y"] + } + })) + .expect("Tool Definition") + } + + async fn call(&self, args: Self::Args) -> Result { + let result = args.x - args.y; + Ok(result) + } +} + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + tracing_subscriber::fmt().init(); + // Create agent with a single context prompt and two tools + let calculator_agent = providers::cohere::Client::from_env() + .agent(providers::cohere::COMMAND_R) + .preamble( + "You are a calculator here to help the user perform arithmetic + operations. Use the tools provided to answer the user's question. + make your answer long, so we can test the streaming functionality, + like 20 words", + ) + .max_tokens(1024) + .tool(Adder) + .tool(Subtract) + .build(); + + println!("Calculate 2 - 5"); + let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?; + stream_to_stdout(calculator_agent, &mut stream).await?; + + if let Some(response) = stream.response { + println!("Usage: {:?} tokens", response.usage); + }; + + println!("Message: {:?}", stream.message); + + Ok(()) +} diff --git a/rig-core/src/providers/cohere/completion.rs b/rig-core/src/providers/cohere/completion.rs index 9621bf7..ca41692 100644 --- a/rig-core/src/providers/cohere/completion.rs +++ b/rig-core/src/providers/cohere/completion.rs @@ -6,8 +6,9 @@ use crate::{ }; use super::client::Client; +use crate::completion::CompletionRequest; use serde::{Deserialize, Serialize}; -use serde_json::json; +use serde_json::{json, Value}; #[derive(Debug, Deserialize)] pub struct CompletionResponse { @@ -419,7 +420,7 @@ impl TryFrom for message::Message { #[derive(Clone)] pub struct CompletionModel { - client: Client, + pub(crate) client: Client, pub model: String, } @@ -430,16 +431,11 @@ impl CompletionModel { model: model.to_string(), } } -} -impl completion::CompletionModel for CompletionModel { - type Response = CompletionResponse; - - #[cfg_attr(feature = "worker", worker::send)] - async fn completion( + pub(crate) fn create_completion_request( &self, - completion_request: completion::CompletionRequest, - ) -> Result, CompletionError> { + completion_request: CompletionRequest, + ) -> Result { let prompt = completion_request.prompt_with_context(); let mut messages: Vec = @@ -468,23 +464,29 @@ impl completion::CompletionModel for CompletionModel { "tools": completion_request.tools.into_iter().map(Tool::from).collect::>(), }); + if let Some(ref params) = completion_request.additional_params { 
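+            // `json_utils::merge` overlays the caller's additional params on top
+            // of the base request JSON built above; e.g. (a sketch, with `p` as a
+            // hypothetical extra sampling parameter):
+            //
+            //   base:   {"model": "command-r", "temperature": 0.7}
+            //   params: {"p": 0.9}
+            //   merged: {"model": "command-r", "temperature": 0.7, "p": 0.9}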
+ Ok(json_utils::merge(request.clone(), params.clone())) + } else { + Ok(request) + } + } +} + +impl completion::CompletionModel for CompletionModel { + type Response = CompletionResponse; + + #[cfg_attr(feature = "worker", worker::send)] + async fn completion( + &self, + completion_request: completion::CompletionRequest, + ) -> Result, CompletionError> { + let request = self.create_completion_request(completion_request)?; tracing::debug!( "Cohere request: {}", serde_json::to_string_pretty(&request)? ); - let response = self - .client - .post("/v2/chat") - .json( - &if let Some(ref params) = completion_request.additional_params { - json_utils::merge(request.clone(), params.clone()) - } else { - request.clone() - }, - ) - .send() - .await?; + let response = self.client.post("/v2/chat").json(&request).send().await?; if response.status().is_success() { let text_response = response.text().await?; diff --git a/rig-core/src/providers/cohere/mod.rs b/rig-core/src/providers/cohere/mod.rs index 4c290c0..b88dc32 100644 --- a/rig-core/src/providers/cohere/mod.rs +++ b/rig-core/src/providers/cohere/mod.rs @@ -12,6 +12,7 @@ pub mod client; pub mod completion; pub mod embeddings; +pub mod streaming; pub use client::Client; pub use client::{ApiErrorResponse, ApiResponse}; @@ -23,7 +24,7 @@ pub use embeddings::EmbeddingModel; // ================================================================ /// `command-r-plus` completion model -pub const COMMAND_R_PLUS: &str = "comman-r-plus"; +pub const COMMAND_R_PLUS: &str = "command-r-plus"; /// `command-r` completion model pub const COMMAND_R: &str = "command-r"; /// `command` completion model diff --git a/rig-core/src/providers/cohere/streaming.rs b/rig-core/src/providers/cohere/streaming.rs new file mode 100644 index 0000000..d097bc8 --- /dev/null +++ b/rig-core/src/providers/cohere/streaming.rs @@ -0,0 +1,195 @@ +use crate::completion::{CompletionError, CompletionRequest}; +use crate::message::{ToolCall, ToolFunction}; +use crate::providers::cohere::completion::{AssistantContent, BilledUnits, Message, Usage}; +use crate::providers::cohere::CompletionModel; +use crate::streaming::{RawStreamingChoice, StreamingCompletionModel}; +use crate::{json_utils, streaming}; +use async_stream::stream; +use futures::StreamExt; +use serde::Deserialize; +use serde_json::{json, Value}; +use std::collections::HashMap; +use std::future::Future; + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "kebab-case", tag = "type")] +pub enum StreamingEvent { + MessageStart, + ContentStart, + ContentDelta { delta: Option }, + ContentEnd, + ToolPlan { delta: Option }, + ToolCallStart { delta: Option }, + ToolCallDelta { delta: Option }, + ToolCallEnd { delta: Option }, + MessageEnd { delta: Option }, +} + +#[derive(Debug, Deserialize)] +struct MessageContentDelta { + r#type: Option, + text: Option, +} + +#[derive(Debug, Deserialize)] +struct MessageToolFunctionDelta { + name: Option, + arguments: Option, +} + +#[derive(Debug, Deserialize)] +struct MessageToolCallDelta { + id: Option, + r#type: Option, + function: Option, +} + +#[derive(Debug, Deserialize)] +struct MessageDelta { + content: Option, + tool_plan: Option, + tool_calls: Option, +} + +#[derive(Debug, Deserialize)] +struct Delta { + message: Option, +} + +#[derive(Debug, Deserialize)] +struct MessageEndDelta { + finish_reason: Option, + usage: Option, +} + +#[derive(Clone)] +pub struct StreamingCompletionResponse { + pub usage: Option, +} + +impl StreamingCompletionModel for CompletionModel { + type StreamingResponse = 
StreamingCompletionResponse; + + async fn stream( + &self, + request: CompletionRequest, + ) -> Result, CompletionError> + { + let request = self.create_completion_request(request)?; + let request = json_utils::merge(request, json!({"stream": true})); + + tracing::debug!( + "Cohere request: {}", + serde_json::to_string_pretty(&request)? + ); + + let response = self.client.post("/v2/chat").json(&request).send().await?; + + if !response.status().is_success() { + return Err(CompletionError::ProviderError(format!( + "{}: {}", + response.status(), + response.text().await? + ))); + } + + let stream = Box::pin(stream! { + let mut stream = response.bytes_stream(); + let mut current_tool_call: Option<(String, String, String)> = None; + + while let Some(chunk_result) = stream.next().await { + let chunk = match chunk_result { + Ok(c) => c, + Err(e) => { + yield Err(CompletionError::from(e)); + break; + } + }; + + let text = match String::from_utf8(chunk.to_vec()) { + Ok(t) => t, + Err(e) => { + yield Err(CompletionError::ResponseError(e.to_string())); + break; + } + }; + + for line in text.lines() { + + let Some(line) = line.strip_prefix("data: ") else { + continue; + }; + + let event = { + let result = serde_json::from_str::(&line); + + let Ok(event) = result else { + continue; + }; + + event + }; + + match event { + StreamingEvent::ContentDelta { delta: Some(delta) } => { + let Some(message) = &delta.message else { continue; }; + let Some(content) = &message.content else { continue; }; + let Some(text) = &content.text else { continue; }; + + yield Ok(RawStreamingChoice::Message(text.clone())); + }, + StreamingEvent::MessageEnd {delta: Some(delta)} => { + yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse { + usage: delta.usage.clone() + })); + }, + StreamingEvent::ToolCallStart { delta: Some(delta)} => { + // Skip the delta if there's any missing information, + // though this *should* all be present + let Some(message) = &delta.message else { continue; }; + let Some(tool_calls) = &message.tool_calls else { continue; }; + let Some(id) = tool_calls.id.clone() else { continue; }; + let Some(function) = &tool_calls.function else { continue; }; + let Some(name) = function.name.clone() else { continue; }; + let Some(arguments) = function.arguments.clone() else { continue; }; + + current_tool_call = Some((id, name, arguments)); + }, + StreamingEvent::ToolCallDelta { delta: Some(delta)} => { + // Skip the delta if there's any missing information, + // though this *should* all be present + let Some(message) = &delta.message else { continue; }; + let Some(tool_calls) = &message.tool_calls else { continue; }; + let Some(function) = &tool_calls.function else { continue; }; + let Some(arguments) = function.arguments.clone() else { continue; }; + + if let Some(tc) = current_tool_call.clone() { + current_tool_call = Some(( + tc.0, + tc.1, + format!("{}{}", tc.2, arguments) + )); + }; + }, + StreamingEvent::ToolCallEnd { .. 
} => { + let Some(tc) = current_tool_call.clone() else { continue; }; + + let Ok(args) = serde_json::from_str(&tc.2) else { continue; }; + + yield Ok(RawStreamingChoice::ToolCall( + tc.0, + tc.1, + args + )); + + current_tool_call = None; + }, + _ => {} + }; + } + } + }); + + Ok(streaming::StreamingCompletionResponse::new(stream)) + } +} diff --git a/rig-core/src/streaming.rs b/rig-core/src/streaming.rs index 9eff6ab..c4bcdef 100644 --- a/rig-core/src/streaming.rs +++ b/rig-core/src/streaming.rs @@ -35,27 +35,6 @@ pub enum RawStreamingChoice { FinalResponse(R), } -/// Enum representing a streaming chunk from the model -#[derive(Debug, Clone)] -pub enum StreamingChoice { - /// A text chunk from a message response - Message(String), - - /// A tool call response chunk - ToolCall(String, String, serde_json::Value), -} - -impl Display for StreamingChoice { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - StreamingChoice::Message(text) => write!(f, "{}", text), - StreamingChoice::ToolCall(name, id, params) => { - write!(f, "Tool call: {} {} {:?}", name, id, params) - } - } - } -} - #[cfg(not(target_arch = "wasm32"))] pub type StreamingResult = Pin, CompletionError>> + Send>>; @@ -85,7 +64,7 @@ impl StreamingCompletionResponse { } impl Stream for StreamingCompletionResponse { - type Item = Result; + type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let stream = self.get_mut(); @@ -114,13 +93,13 @@ impl Stream for StreamingCompletionResponse { Poll::Ready(Some(Ok(choice))) => match choice { RawStreamingChoice::Message(text) => { stream.text = format!("{}{}", stream.text, text.clone()); - Poll::Ready(Some(Ok(StreamingChoice::Message(text)))) + Poll::Ready(Some(Ok(AssistantContent::text(text)))) } - RawStreamingChoice::ToolCall(name, description, args) => { + RawStreamingChoice::ToolCall(id, name, args) => { stream .tool_calls - .push((name.clone(), description.clone(), args.clone())); - Poll::Ready(Some(Ok(StreamingChoice::ToolCall(name, description, args)))) + .push((id.clone(), name.clone(), args.clone())); + Poll::Ready(Some(Ok(AssistantContent::tool_call(id, name, args)))) } RawStreamingChoice::FinalResponse(response) => { stream.response = Some(response); @@ -181,14 +160,17 @@ pub async fn stream_to_stdout( print!("Response: "); while let Some(chunk) = stream.next().await { match chunk { - Ok(StreamingChoice::Message(text)) => { - print!("{}", text); + Ok(AssistantContent::Text(text)) => { + print!("{}", text.text); std::io::Write::flush(&mut std::io::stdout())?; } - Ok(StreamingChoice::ToolCall(name, _, params)) => { + Ok(AssistantContent::ToolCall(tool_call)) => { let res = agent .tools - .call(&name, params.to_string()) + .call( + &tool_call.function.name, + tool_call.function.arguments.to_string(), + ) .await .map_err(|e| std::io::Error::other(e.to_string()))?; println!("\nResult: {}", res); From 4f023046d54642fb368e15abda41bcc310e0017f Mon Sep 17 00:00:00 2001 From: yavens <179155341+yavens@users.github.noreply.github.com> Date: Thu, 10 Apr 2025 19:59:44 -0400 Subject: [PATCH 06/16] chore: fix clippy --- rig-core/src/providers/cohere/streaming.rs | 33 +++++++++------------- rig-core/src/streaming.rs | 1 - 2 files changed, 13 insertions(+), 21 deletions(-) diff --git a/rig-core/src/providers/cohere/streaming.rs b/rig-core/src/providers/cohere/streaming.rs index d097bc8..c1e6938 100644 --- a/rig-core/src/providers/cohere/streaming.rs +++ b/rig-core/src/providers/cohere/streaming.rs @@ -1,33 +1,29 @@ use 
crate::completion::{CompletionError, CompletionRequest}; -use crate::message::{ToolCall, ToolFunction}; -use crate::providers::cohere::completion::{AssistantContent, BilledUnits, Message, Usage}; +use crate::providers::cohere::completion::Usage; use crate::providers::cohere::CompletionModel; use crate::streaming::{RawStreamingChoice, StreamingCompletionModel}; use crate::{json_utils, streaming}; use async_stream::stream; use futures::StreamExt; use serde::Deserialize; -use serde_json::{json, Value}; -use std::collections::HashMap; -use std::future::Future; +use serde_json::json; #[derive(Debug, Deserialize)] #[serde(rename_all = "kebab-case", tag = "type")] -pub enum StreamingEvent { +enum StreamingEvent { MessageStart, ContentStart, ContentDelta { delta: Option }, ContentEnd, - ToolPlan { delta: Option }, + ToolPlan, ToolCallStart { delta: Option }, ToolCallDelta { delta: Option }, - ToolCallEnd { delta: Option }, + ToolCallEnd, MessageEnd { delta: Option }, } #[derive(Debug, Deserialize)] struct MessageContentDelta { - r#type: Option, text: Option, } @@ -40,14 +36,12 @@ struct MessageToolFunctionDelta { #[derive(Debug, Deserialize)] struct MessageToolCallDelta { id: Option, - r#type: Option, function: Option, } #[derive(Debug, Deserialize)] struct MessageDelta { content: Option, - tool_plan: Option, tool_calls: Option, } @@ -58,7 +52,6 @@ struct Delta { #[derive(Debug, Deserialize)] struct MessageEndDelta { - finish_reason: Option, usage: Option, } @@ -119,9 +112,9 @@ impl StreamingCompletionModel for CompletionModel { let Some(line) = line.strip_prefix("data: ") else { continue; }; - + let event = { - let result = serde_json::from_str::(&line); + let result = serde_json::from_str::(line); let Ok(event) = result else { continue; @@ -129,13 +122,13 @@ impl StreamingCompletionModel for CompletionModel { event }; - + match event { StreamingEvent::ContentDelta { delta: Some(delta) } => { let Some(message) = &delta.message else { continue; }; let Some(content) = &message.content else { continue; }; let Some(text) = &content.text else { continue; }; - + yield Ok(RawStreamingChoice::Message(text.clone())); }, StreamingEvent::MessageEnd {delta: Some(delta)} => { @@ -145,7 +138,7 @@ impl StreamingCompletionModel for CompletionModel { }, StreamingEvent::ToolCallStart { delta: Some(delta)} => { // Skip the delta if there's any missing information, - // though this *should* all be present + // though this *should* all be present let Some(message) = &delta.message else { continue; }; let Some(tool_calls) = &message.tool_calls else { continue; }; let Some(id) = tool_calls.id.clone() else { continue; }; @@ -173,15 +166,15 @@ impl StreamingCompletionModel for CompletionModel { }, StreamingEvent::ToolCallEnd { .. 
} => { let Some(tc) = current_tool_call.clone() else { continue; }; - + let Ok(args) = serde_json::from_str(&tc.2) else { continue; }; - + yield Ok(RawStreamingChoice::ToolCall( tc.0, tc.1, args )); - + current_tool_call = None; }, _ => {} diff --git a/rig-core/src/streaming.rs b/rig-core/src/streaming.rs index c4bcdef..ac68953 100644 --- a/rig-core/src/streaming.rs +++ b/rig-core/src/streaming.rs @@ -17,7 +17,6 @@ use crate::message::AssistantContent; use crate::OneOrMany; use futures::{Stream, StreamExt}; use std::boxed::Box; -use std::fmt::{Display, Formatter}; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; From ce37d44b15eeb24d161da708d629c54af32e6991 Mon Sep 17 00:00:00 2001 From: yavens <179155341+yavens@users.github.noreply.github.com> Date: Thu, 10 Apr 2025 20:01:43 -0400 Subject: [PATCH 07/16] chore: clippy again --- rig-core/src/providers/cohere/streaming.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rig-core/src/providers/cohere/streaming.rs b/rig-core/src/providers/cohere/streaming.rs index c1e6938..6b9db9f 100644 --- a/rig-core/src/providers/cohere/streaming.rs +++ b/rig-core/src/providers/cohere/streaming.rs @@ -164,7 +164,7 @@ impl StreamingCompletionModel for CompletionModel { )); }; }, - StreamingEvent::ToolCallEnd { .. } => { + StreamingEvent::ToolCallEnd => { let Some(tc) = current_tool_call.clone() else { continue; }; let Ok(args) = serde_json::from_str(&tc.2) else { continue; }; From 6dcabc1e0edd3568f5be9ad473f579179604f2d9 Mon Sep 17 00:00:00 2001 From: yavens <179155341+yavens@users.github.noreply.github.com> Date: Thu, 10 Apr 2025 20:21:54 -0400 Subject: [PATCH 08/16] chore: comment the new streaming code --- rig-core/src/streaming.rs | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/rig-core/src/streaming.rs b/rig-core/src/streaming.rs index ac68953..c2aacbc 100644 --- a/rig-core/src/streaming.rs +++ b/rig-core/src/streaming.rs @@ -30,7 +30,8 @@ pub enum RawStreamingChoice { /// A tool call response chunk ToolCall(String, String, serde_json::Value), - /// The final response object + /// The final response object, must be yielded if you want the + /// `response` field to be populated on the `StreamingCompletionResponse` FinalResponse(R), } @@ -42,11 +43,18 @@ pub type StreamingResult = pub type StreamingResult = Pin, CompletionError>>>>; +/// The response from a streaming completion request; +/// message and response are populated at the end of the +/// `inner` stream. pub struct StreamingCompletionResponse { inner: StreamingResult, text: String, tool_calls: Vec<(String, String, serde_json::Value)>, + /// The final aggregated message from the stream + /// contains all text and tool calls generated pub message: Message, + /// The final response from the stream, may be `None` + /// if the provider didn't yield it during the stream pub response: Option, } @@ -71,12 +79,15 @@ impl Stream for StreamingCompletionResponse { match stream.inner.as_mut().poll_next(cx) { Poll::Pending => Poll::Pending, Poll::Ready(None) => { + // This is run at the end of the inner stream to collect all tokens into + // a single unified `Message`. 
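+                    //
+                    // For example, a stream that produced the text "2 - 5 = -3"
+                    // plus one tool call aggregates into roughly (a sketch; the
+                    // id/name/args are whatever the provider yielded):
+                    //
+                    //   Message::Assistant {
+                    //       content: OneOrMany::many(vec![
+                    //           AssistantContent::text("2 - 5 = -3"),
+                    //           AssistantContent::tool_call(id, name, args),
+                    //       ]).expect("at least one content item"),
+                    //   }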
let mut content = vec![]; stream.tool_calls.iter().for_each(|(n, d, a)| { content.push(AssistantContent::tool_call(n, d, a.clone())); }); + // This is required to ensure there's always at least one item in the content if content.is_empty() || !stream.text.is_empty() { content.insert(0, AssistantContent::text(stream.text.clone())); } @@ -91,16 +102,21 @@ impl Stream for StreamingCompletionResponse { Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), Poll::Ready(Some(Ok(choice))) => match choice { RawStreamingChoice::Message(text) => { + // Forward the streaming tokens to the outer stream + // and concat the text together stream.text = format!("{}{}", stream.text, text.clone()); Poll::Ready(Some(Ok(AssistantContent::text(text)))) } RawStreamingChoice::ToolCall(id, name, args) => { + // Keep track of each tool call to aggregate the final message later + // and pass it to the outer stream stream .tool_calls .push((id.clone(), name.clone(), args.clone())); Poll::Ready(Some(Ok(AssistantContent::tool_call(id, name, args)))) } RawStreamingChoice::FinalResponse(response) => { + // Set the final response field and return the next item in the stream stream.response = Some(response); stream.poll_next_unpin(cx) From cb4be5c6dabb593d9dc9d6744f9774f40f3ba17c Mon Sep 17 00:00:00 2001 From: yavens <179155341+yavens@users.github.noreply.github.com> Date: Fri, 11 Apr 2025 09:41:14 -0400 Subject: [PATCH 09/16] fix: remove unnecessary tool call tuple --- rig-core/src/streaming.rs | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/rig-core/src/streaming.rs b/rig-core/src/streaming.rs index c2aacbc..603d483 100644 --- a/rig-core/src/streaming.rs +++ b/rig-core/src/streaming.rs @@ -13,7 +13,7 @@ use crate::agent::Agent; use crate::completion::{ CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, Message, }; -use crate::message::AssistantContent; +use crate::message::{AssistantContent, ToolCall, ToolFunction}; use crate::OneOrMany; use futures::{Stream, StreamExt}; use std::boxed::Box; @@ -49,7 +49,7 @@ pub type StreamingResult = pub struct StreamingCompletionResponse { inner: StreamingResult, text: String, - tool_calls: Vec<(String, String, serde_json::Value)>, + tool_calls: Vec, /// The final aggregated message from the stream /// contains all text and tool calls generated pub message: Message, @@ -83,8 +83,8 @@ impl Stream for StreamingCompletionResponse { // a single unified `Message`. 
let mut content = vec![]; - stream.tool_calls.iter().for_each(|(n, d, a)| { - content.push(AssistantContent::tool_call(n, d, a.clone())); + stream.tool_calls.iter().for_each(|tc| { + content.push(AssistantContent::ToolCall(tc.clone())); }); // This is required to ensure there's always at least one item in the content @@ -110,9 +110,13 @@ impl Stream for StreamingCompletionResponse { RawStreamingChoice::ToolCall(id, name, args) => { // Keep track of each tool call to aggregate the final message later // and pass it to the outer stream - stream - .tool_calls - .push((id.clone(), name.clone(), args.clone())); + stream.tool_calls.push(ToolCall { + id: id.clone(), + function: ToolFunction { + name: name.clone(), + arguments: args.clone(), + }, + }); Poll::Ready(Some(Ok(AssistantContent::tool_call(id, name, args)))) } RawStreamingChoice::FinalResponse(response) => { From 9689693468080cbe1cfd2ffce97818b10b9d7051 Mon Sep 17 00:00:00 2001 From: yavens <179155341+yavens@users.github.noreply.github.com> Date: Tue, 15 Apr 2025 09:06:31 -0400 Subject: [PATCH 10/16] feat: change StreamingCompletionResponse to resemble CompletionResponse --- rig-core/src/streaming.rs | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/rig-core/src/streaming.rs b/rig-core/src/streaming.rs index 603d483..100f699 100644 --- a/rig-core/src/streaming.rs +++ b/rig-core/src/streaming.rs @@ -11,7 +11,8 @@ use crate::agent::Agent; use crate::completion::{ - CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, Message, + CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, + CompletionResponse, Message, }; use crate::message::{AssistantContent, ToolCall, ToolFunction}; use crate::OneOrMany; @@ -52,7 +53,7 @@ pub struct StreamingCompletionResponse { tool_calls: Vec, /// The final aggregated message from the stream /// contains all text and tool calls generated - pub message: Message, + pub choice: OneOrMany, /// The final response from the stream, may be `None` /// if the provider didn't yield it during the stream pub response: Option, @@ -64,12 +65,21 @@ impl StreamingCompletionResponse { inner, text: "".to_string(), tool_calls: vec![], - message: Message::assistant(""), + choice: OneOrMany::one(AssistantContent::text("")), response: None, } } } +impl Into>> for StreamingCompletionResponse { + fn into(self) -> CompletionResponse> { + CompletionResponse { + choice: self.choice, + raw_response: self.response, + } + } +} + impl Stream for StreamingCompletionResponse { type Item = Result; From bd4f5f27dafd0311e05e7e0958408f230e88e4fb Mon Sep 17 00:00:00 2001 From: yavens <179155341+yavens@users.github.noreply.github.com> Date: Tue, 15 Apr 2025 09:18:28 -0400 Subject: [PATCH 11/16] fix: field wasn't renamed --- rig-core/src/streaming.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/rig-core/src/streaming.rs b/rig-core/src/streaming.rs index 100f699..f379c28 100644 --- a/rig-core/src/streaming.rs +++ b/rig-core/src/streaming.rs @@ -91,21 +91,19 @@ impl Stream for StreamingCompletionResponse { Poll::Ready(None) => { // This is run at the end of the inner stream to collect all tokens into // a single unified `Message`. 
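+                    // Once the stream is drained, the aggregate can also be turned
+                    // into a regular completion response via the conversion added
+                    // in this series, roughly (a sketch):
+                    //
+                    //   let resp: CompletionResponse<Option<R>> = streaming_resp.into();
+                    //
+                    // with `raw_response` left as `None` if the provider never
+                    // yielded a `FinalResponse`.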
- let mut content = vec![]; + let mut choice = vec![]; stream.tool_calls.iter().for_each(|tc| { - content.push(AssistantContent::ToolCall(tc.clone())); + choice.push(AssistantContent::ToolCall(tc.clone())); }); // This is required to ensure there's always at least one item in the content - if content.is_empty() || !stream.text.is_empty() { - content.insert(0, AssistantContent::text(stream.text.clone())); + if choice.is_empty() || !stream.text.is_empty() { + choice.insert(0, AssistantContent::text(stream.text.clone())); } - stream.message = Message::Assistant { - content: OneOrMany::many(content) - .expect("There should be at least one assistant message"), - }; + stream.choice = OneOrMany::many(choice) + .expect("There should be at least one assistant message"); Poll::Ready(None) } From 651a113d5c2ca552dd72cf9c3ef03598b4cac5e6 Mon Sep 17 00:00:00 2001 From: yavens <179155341+yavens@users.github.noreply.github.com> Date: Tue, 15 Apr 2025 09:29:38 -0400 Subject: [PATCH 12/16] chore: fmt --- rig-core/src/streaming.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rig-core/src/streaming.rs b/rig-core/src/streaming.rs index f379c28..2d720a9 100644 --- a/rig-core/src/streaming.rs +++ b/rig-core/src/streaming.rs @@ -103,7 +103,7 @@ impl Stream for StreamingCompletionResponse { } stream.choice = OneOrMany::many(choice) - .expect("There should be at least one assistant message"); + .expect("There should be at least one assistant message"); Poll::Ready(None) } From c28210df902be211f299758a175ebc17d0bf9684 Mon Sep 17 00:00:00 2001 From: yavens <179155341+yavens@users.github.noreply.github.com> Date: Tue, 15 Apr 2025 09:33:28 -0400 Subject: [PATCH 13/16] fix: Into -> From --- rig-core/src/streaming.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rig-core/src/streaming.rs b/rig-core/src/streaming.rs index 2d720a9..7a58498 100644 --- a/rig-core/src/streaming.rs +++ b/rig-core/src/streaming.rs @@ -71,11 +71,11 @@ impl StreamingCompletionResponse { } } -impl Into>> for StreamingCompletionResponse { - fn into(self) -> CompletionResponse> { +impl From> for CompletionResponse> { + fn from(value: StreamingCompletionResponse) -> CompletionResponse> { CompletionResponse { - choice: self.choice, - raw_response: self.response, + choice: value.choice, + raw_response: value.response, } } } From b7c80c6d19fb1a459626d5cfb9639d9f5be80b00 Mon Sep 17 00:00:00 2001 From: yavens <179155341+yavens@users.github.noreply.github.com> Date: Tue, 15 Apr 2025 09:38:03 -0400 Subject: [PATCH 14/16] fix: update examples --- rig-core/examples/anthropic_streaming.rs | 2 +- rig-core/examples/anthropic_streaming_with_tools.rs | 2 +- rig-core/examples/cohere_streaming.rs | 2 +- rig-core/examples/cohere_streaming_with_tools.rs | 2 +- rig-core/examples/gemini_streaming.rs | 2 +- rig-core/examples/gemini_streaming_with_tools.rs | 2 +- rig-core/examples/ollama_streaming.rs | 2 +- rig-core/examples/ollama_streaming_with_tools.rs | 2 +- rig-core/examples/openai_streaming.rs | 2 +- rig-core/examples/openai_streaming_with_tools.rs | 2 +- 10 files changed, 10 insertions(+), 10 deletions(-) diff --git a/rig-core/examples/anthropic_streaming.rs b/rig-core/examples/anthropic_streaming.rs index b61ef96..fa0c66f 100644 --- a/rig-core/examples/anthropic_streaming.rs +++ b/rig-core/examples/anthropic_streaming.rs @@ -23,7 +23,7 @@ async fn main() -> Result<(), anyhow::Error> { println!("Usage: {:?} tokens", response.usage.output_tokens); }; - println!("Message: {:?}", stream.message); + println!("Message: 
{:?}", stream.choice); Ok(()) } diff --git a/rig-core/examples/anthropic_streaming_with_tools.rs b/rig-core/examples/anthropic_streaming_with_tools.rs index 9dc53b4..5914d10 100644 --- a/rig-core/examples/anthropic_streaming_with_tools.rs +++ b/rig-core/examples/anthropic_streaming_with_tools.rs @@ -112,7 +112,7 @@ async fn main() -> Result<(), anyhow::Error> { println!("Usage: {:?} tokens", response.usage.output_tokens); }; - println!("Message: {:?}", stream.message); + println!("Message: {:?}", stream.choice); Ok(()) } diff --git a/rig-core/examples/cohere_streaming.rs b/rig-core/examples/cohere_streaming.rs index bf1f16c..d6fb6eb 100644 --- a/rig-core/examples/cohere_streaming.rs +++ b/rig-core/examples/cohere_streaming.rs @@ -21,7 +21,7 @@ async fn main() -> Result<(), anyhow::Error> { println!("Usage: {:?} tokens", response.usage); }; - println!("Message: {:?}", stream.message); + println!("Message: {:?}", stream.choice); Ok(()) } diff --git a/rig-core/examples/cohere_streaming_with_tools.rs b/rig-core/examples/cohere_streaming_with_tools.rs index c4d5961..53012d1 100644 --- a/rig-core/examples/cohere_streaming_with_tools.rs +++ b/rig-core/examples/cohere_streaming_with_tools.rs @@ -112,7 +112,7 @@ async fn main() -> Result<(), anyhow::Error> { println!("Usage: {:?} tokens", response.usage); }; - println!("Message: {:?}", stream.message); + println!("Message: {:?}", stream.choice); Ok(()) } diff --git a/rig-core/examples/gemini_streaming.rs b/rig-core/examples/gemini_streaming.rs index fc22907..6fa34ae 100644 --- a/rig-core/examples/gemini_streaming.rs +++ b/rig-core/examples/gemini_streaming.rs @@ -26,6 +26,6 @@ async fn main() -> Result<(), anyhow::Error> { ); }; - println!("Message: {:?}", stream.message); + println!("Message: {:?}", stream.choice); Ok(()) } diff --git a/rig-core/examples/gemini_streaming_with_tools.rs b/rig-core/examples/gemini_streaming_with_tools.rs index 8c9bffe..ffdd135 100644 --- a/rig-core/examples/gemini_streaming_with_tools.rs +++ b/rig-core/examples/gemini_streaming_with_tools.rs @@ -115,7 +115,7 @@ async fn main() -> Result<(), anyhow::Error> { ); }; - println!("Message: {:?}", stream.message); + println!("Message: {:?}", stream.choice); Ok(()) } diff --git a/rig-core/examples/ollama_streaming.rs b/rig-core/examples/ollama_streaming.rs index 1bc1e33..9c745f5 100644 --- a/rig-core/examples/ollama_streaming.rs +++ b/rig-core/examples/ollama_streaming.rs @@ -21,6 +21,6 @@ async fn main() -> Result<(), anyhow::Error> { println!("Usage: {:?} tokens", response.eval_count); }; - println!("Message: {:?}", stream.message); + println!("Message: {:?}", stream.choice); Ok(()) } diff --git a/rig-core/examples/ollama_streaming_with_tools.rs b/rig-core/examples/ollama_streaming_with_tools.rs index 37b5854..d61f16e 100644 --- a/rig-core/examples/ollama_streaming_with_tools.rs +++ b/rig-core/examples/ollama_streaming_with_tools.rs @@ -112,7 +112,7 @@ async fn main() -> Result<(), anyhow::Error> { println!("Usage: {:?} tokens", response.eval_count); }; - println!("Message: {:?}", stream.message); + println!("Message: {:?}", stream.choice); Ok(()) } diff --git a/rig-core/examples/openai_streaming.rs b/rig-core/examples/openai_streaming.rs index 40ec569..87772da 100644 --- a/rig-core/examples/openai_streaming.rs +++ b/rig-core/examples/openai_streaming.rs @@ -21,7 +21,7 @@ async fn main() -> Result<(), anyhow::Error> { println!("Usage: {:?}", response.usage) }; - println!("Message: {:?}", stream.message); + println!("Message: {:?}", stream.choice); Ok(()) } diff --git 
a/rig-core/examples/openai_streaming_with_tools.rs b/rig-core/examples/openai_streaming_with_tools.rs index b83a1fa..6d57855 100644 --- a/rig-core/examples/openai_streaming_with_tools.rs +++ b/rig-core/examples/openai_streaming_with_tools.rs @@ -112,7 +112,7 @@ async fn main() -> Result<(), anyhow::Error> { println!("Usage: {:?}", response.usage) }; - println!("Message: {:?}", stream.message); + println!("Message: {:?}", stream.choice); Ok(()) } From c432a6bdcdceb6f1d099a10eda5172b324a4e5b7 Mon Sep 17 00:00:00 2001 From: yavens <179155341+yavens@users.github.noreply.github.com> Date: Tue, 15 Apr 2025 16:15:48 -0400 Subject: [PATCH 15/16] feat: openrouter streaming --- rig-core/src/providers/openai/streaming.rs | 34 +- rig-core/src/providers/openrouter/client.rs | 125 +++++++ .../completion.rs} | 180 ++-------- rig-core/src/providers/openrouter/mod.rs | 17 + .../src/providers/openrouter/streaming.rs | 313 ++++++++++++++++++ 5 files changed, 506 insertions(+), 163 deletions(-) create mode 100644 rig-core/src/providers/openrouter/client.rs rename rig-core/src/providers/{openrouter.rs => openrouter/completion.rs} (57%) create mode 100644 rig-core/src/providers/openrouter/mod.rs create mode 100644 rig-core/src/providers/openrouter/streaming.rs diff --git a/rig-core/src/providers/openai/streaming.rs b/rig-core/src/providers/openai/streaming.rs index c3e83ac..ad6ac41 100644 --- a/rig-core/src/providers/openai/streaming.rs +++ b/rig-core/src/providers/openai/streaming.rs @@ -11,6 +11,7 @@ use reqwest::RequestBuilder; use serde::{Deserialize, Serialize}; use serde_json::json; use std::collections::HashMap; +use tracing::debug; // ================================================================ // OpenAI Completion Streaming API @@ -26,10 +27,11 @@ pub struct StreamingFunction { #[derive(Debug, Serialize, Deserialize, Clone)] pub struct StreamingToolCall { pub index: usize, + pub id: Option, pub function: StreamingFunction, } -#[derive(Deserialize)] +#[derive(Deserialize, Debug)] struct StreamingDelta { #[serde(default)] content: Option, @@ -37,12 +39,12 @@ struct StreamingDelta { tool_calls: Vec, } -#[derive(Deserialize)] +#[derive(Deserialize, Debug)] struct StreamingChoice { delta: StreamingDelta, } -#[derive(Deserialize)] +#[derive(Deserialize, Debug)] struct StreamingCompletionChunk { choices: Vec, usage: Option, @@ -94,7 +96,7 @@ pub async fn send_compatible_streaming_request( }; let mut partial_data = None; - let mut calls: HashMap = HashMap::new(); + let mut calls: HashMap = HashMap::new(); while let Some(chunk_result) = stream.next().await { let chunk = match chunk_result { @@ -139,6 +141,8 @@ pub async fn send_compatible_streaming_request( let data = serde_json::from_str::(&line); let Ok(data) = data else { + let err = data.unwrap_err(); + debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err); continue; }; @@ -150,35 +154,39 @@ pub async fn send_compatible_streaming_request( if !delta.tool_calls.is_empty() { for tool_call in &delta.tool_calls { let function = tool_call.function.clone(); - // Start of tool call // name: Some(String) // arguments: None if function.name.is_some() && function.arguments.is_empty() { - calls.insert(tool_call.index, (function.name.clone().unwrap(), "".to_string())); + let id = tool_call.id.clone().unwrap_or("".to_string()); + + calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string())); } // Part of tool call // name: None // arguments: Some(String) else if function.name.is_none() && !function.arguments.is_empty() { - 
let Some((name, arguments)) = calls.get(&tool_call.index) else {
+                                    let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
+                                        debug!("Partial tool call received but tool call was never started.");
                                         continue;
                                     };
 
                                     let new_arguments = &tool_call.function.arguments;
                                     let arguments = format!("{}{}", arguments, new_arguments);
 
-                                    calls.insert(tool_call.index, (name.clone(), arguments));
+                                    calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
                                 }
                                 // Entire tool call
                                 else {
-                                    let name = function.name.unwrap();
+                                    let id = tool_call.id.clone().unwrap_or("".to_string());
+                                    let name = function.name.expect("function name should be present for complete tool call");
                                     let arguments = function.arguments;
                                     let Ok(arguments) = serde_json::from_str(&arguments) else {
+                                        debug!("Couldn't parse '{}' as a JSON value", arguments);
                                         continue;
                                     };
 
-                                    yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments))
+                                    yield Ok(streaming::RawStreamingChoice::ToolCall(id, name, arguments))
                                 }
                             }
                         }
@@ -195,12 +203,14 @@ pub async fn send_compatible_streaming_request(
             }
         }
 
-        for (_, (name, arguments)) in calls {
+        for (_, (id, name, arguments)) in calls {
             let Ok(arguments) = serde_json::from_str(&arguments) else {
                 continue;
             };
 
-            yield Ok(RawStreamingChoice::ToolCall(name, "".to_string(), arguments))
+            yield Ok(RawStreamingChoice::ToolCall(id, name, arguments))
         }
 
         yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
diff --git a/rig-core/src/providers/openrouter/client.rs b/rig-core/src/providers/openrouter/client.rs
new file mode 100644
index 0000000..2eab6c1
--- /dev/null
+++ b/rig-core/src/providers/openrouter/client.rs
@@ -0,0 +1,125 @@
+use crate::{agent::AgentBuilder, extractor::ExtractorBuilder};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+
+use super::completion::CompletionModel;
+
+// ================================================================
+// Main openrouter Client
+// ================================================================
+const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
+
+#[derive(Clone)]
+pub struct Client {
+    base_url: String,
+    http_client: reqwest::Client,
+}
+
+impl Client {
+    /// Create a new OpenRouter client with the given API key.
+    pub fn new(api_key: &str) -> Self {
+        Self::from_url(api_key, OPENROUTER_API_BASE_URL)
+    }
+
+    /// Create a new OpenRouter client with the given API key and base API URL.
+    pub fn from_url(api_key: &str, base_url: &str) -> Self {
+        Self {
+            base_url: base_url.to_string(),
+            http_client: reqwest::Client::builder()
+                .default_headers({
+                    let mut headers = reqwest::header::HeaderMap::new();
+                    headers.insert(
+                        "Authorization",
+                        format!("Bearer {}", api_key)
+                            .parse()
+                            .expect("Bearer token should parse"),
+                    );
+                    headers
+                })
+                .build()
+                .expect("OpenRouter reqwest client should build"),
+        }
+    }
+
+    /// Create a new OpenRouter client from the `OPENROUTER_API_KEY` environment variable.
+    /// Panics if the environment variable is not set.
+    pub fn from_env() -> Self {
+        let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
+        Self::new(&api_key)
+    }
+
+    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
+        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
+        self.http_client.post(url)
+    }
+
+    /// Create a completion model with the given name.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::openrouter::{Client, self};
+    ///
+    /// // Initialize the openrouter client
+    /// let openrouter = Client::new("your-openrouter-api-key");
+    ///
+    /// let llama_3_1_8b = openrouter.completion_model(openrouter::LLAMA_3_1_8B);
+    /// ```
+    pub fn completion_model(&self, model: &str) -> CompletionModel {
+        CompletionModel::new(self.clone(), model)
+    }
+
+    /// Create an agent builder with the given completion model.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::openrouter::{Client, self};
+    ///
+    /// // Initialize the openrouter client
+    /// let openrouter = Client::new("your-openrouter-api-key");
+    ///
+    /// let agent = openrouter.agent(openrouter::LLAMA_3_1_8B)
+    ///     .preamble("You are comedian AI with a mission to make people laugh.")
+    ///     .temperature(0.0)
+    ///     .build();
+    /// ```
+    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
+        AgentBuilder::new(self.completion_model(model))
+    }
+
+    /// Create an extractor builder with the given completion model.
+    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
+        &self,
+        model: &str,
+    ) -> ExtractorBuilder<T, CompletionModel> {
+        ExtractorBuilder::new(self.completion_model(model))
+    }
+}
+
+#[derive(Debug, Deserialize)]
+pub struct ApiErrorResponse {
+    pub(crate) message: String,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(untagged)]
+pub enum ApiResponse<T> {
+    Ok(T),
+    Err(ApiErrorResponse),
+}
+
+#[derive(Clone, Debug, Deserialize)]
+pub struct Usage {
+    pub prompt_tokens: usize,
+    pub completion_tokens: usize,
+    pub total_tokens: usize,
+}
+
+impl std::fmt::Display for Usage {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "Prompt tokens: {} Total tokens: {}",
+            self.prompt_tokens, self.total_tokens
+        )
+    }
+}
diff --git a/rig-core/src/providers/openrouter.rs b/rig-core/src/providers/openrouter/completion.rs
similarity index 57%
rename from rig-core/src/providers/openrouter.rs
rename to rig-core/src/providers/openrouter/completion.rs
index 149b2ff..85b9bfe 100644
--- a/rig-core/src/providers/openrouter.rs
+++ b/rig-core/src/providers/openrouter/completion.rs
@@ -1,147 +1,16 @@
-//! OpenRouter Inference API client and Rig integration
-//!
-//! # Example
-//! ```
-//! use rig::providers::openrouter;
-//!
-//! let client = openrouter::Client::new("YOUR_API_KEY");
-//!
-//! let llama_3_1_8b = client.completion_model(openrouter::LLAMA_3_1_8B);
-//! ```
+use serde::Deserialize;
+
+use super::client::{ApiErrorResponse, ApiResponse, Client, Usage};
 use crate::{
-    agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
-    extractor::ExtractorBuilder,
     json_utils,
     providers::openai::Message,
     OneOrMany,
 };
-use schemars::JsonSchema;
-use serde::{Deserialize, Serialize};
-use serde_json::json;
+use serde_json::{json, Value};
 
-use super::openai::AssistantContent;
-
-// ================================================================
-// Main openrouter Client
-// ================================================================
-const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
-
-#[derive(Clone)]
-pub struct Client {
-    base_url: String,
-    http_client: reqwest::Client,
-}
-
-impl Client {
-    /// Create a new OpenRouter client with the given API key.
-    pub fn new(api_key: &str) -> Self {
-        Self::from_url(api_key, OPENROUTER_API_BASE_URL)
-    }
-
-    /// Create a new OpenRouter client with the given API key and base API URL.
-    pub fn from_url(api_key: &str, base_url: &str) -> Self {
-        Self {
-            base_url: base_url.to_string(),
-            http_client: reqwest::Client::builder()
-                .default_headers({
-                    let mut headers = reqwest::header::HeaderMap::new();
-                    headers.insert(
-                        "Authorization",
-                        format!("Bearer {}", api_key)
-                            .parse()
-                            .expect("Bearer token should parse"),
-                    );
-                    headers
-                })
-                .build()
-                .expect("OpenRouter reqwest client should build"),
-        }
-    }
-
-    /// Create a new openrouter client from the `openrouter_API_KEY` environment variable.
-    /// Panics if the environment variable is not set.
-    pub fn from_env() -> Self {
-        let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
-        Self::new(&api_key)
-    }
-
-    fn post(&self, path: &str) -> reqwest::RequestBuilder {
-        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
-        self.http_client.post(url)
-    }
-
-    /// Create a completion model with the given name.
-    ///
-    /// # Example
-    /// ```
-    /// use rig::providers::openrouter::{Client, self};
-    ///
-    /// // Initialize the openrouter client
-    /// let openrouter = Client::new("your-openrouter-api-key");
-    ///
-    /// let llama_3_1_8b = openrouter.completion_model(openrouter::LLAMA_3_1_8B);
-    /// ```
-    pub fn completion_model(&self, model: &str) -> CompletionModel {
-        CompletionModel::new(self.clone(), model)
-    }
-
-    /// Create an agent builder with the given completion model.
-    ///
-    /// # Example
-    /// ```
-    /// use rig::providers::openrouter::{Client, self};
-    ///
-    /// // Initialize the Eternal client
-    /// let openrouter = Client::new("your-openrouter-api-key");
-    ///
-    /// let agent = openrouter.agent(openrouter::LLAMA_3_1_8B)
-    ///    .preamble("You are comedian AI with a mission to make people laugh.")
-    ///    .temperature(0.0)
-    ///    .build();
-    /// ```
-    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
-        AgentBuilder::new(self.completion_model(model))
-    }
-
-    /// Create an extractor builder with the given completion model.
-    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
-        &self,
-        model: &str,
-    ) -> ExtractorBuilder<T, CompletionModel> {
-        ExtractorBuilder::new(self.completion_model(model))
-    }
-}
-
-#[derive(Debug, Deserialize)]
-struct ApiErrorResponse {
-    message: String,
-}
-
-#[derive(Debug, Deserialize)]
-#[serde(untagged)]
-enum ApiResponse<T> {
-    Ok(T),
-    Err(ApiErrorResponse),
-}
-
-#[derive(Clone, Debug, Deserialize)]
-pub struct Usage {
-    pub prompt_tokens: usize,
-    pub completion_tokens: usize,
-    pub total_tokens: usize,
-}
-
-impl std::fmt::Display for Usage {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(
-            f,
-            "Prompt tokens: {} Total tokens: {}",
-            self.prompt_tokens, self.total_tokens
-        )
-    }
-}
+use crate::providers::openai::AssistantContent;
 
 // ================================================================
 // OpenRouter Completion API
@@ -241,7 +110,7 @@ pub struct Choice {
 
 #[derive(Clone)]
 pub struct CompletionModel {
-    client: Client,
+    pub(crate) client: Client,
     /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
     pub model: String,
 }
@@ -253,16 +122,11 @@ impl CompletionModel {
             model: model.to_string(),
         }
     }
-}
 
-impl completion::CompletionModel for CompletionModel {
-    type Response = CompletionResponse;
-
-    #[cfg_attr(feature = "worker", worker::send)]
-    async fn completion(
+    pub(crate) fn create_completion_request(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
+    ) -> Result<Value, CompletionError> {
         // Add preamble to chat history (if available)
         let mut full_history: Vec<Message> = match &completion_request.preamble {
             Some(preamble) => vec![Message::system(preamble)],
@@ -292,16 +156,30 @@ impl completion::CompletionModel for CompletionModel {
             "temperature": completion_request.temperature,
         });
 
+        let request = if let Some(params) = completion_request.additional_params {
+            json_utils::merge(request, params)
+        } else {
+            request
+        };
+
+        Ok(request)
+    }
+}
+
+impl completion::CompletionModel for CompletionModel {
+    type Response = CompletionResponse;
+
+    #[cfg_attr(feature = "worker", worker::send)]
+    async fn completion(
+        &self,
+        completion_request: CompletionRequest,
+    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
+        let request = self.create_completion_request(completion_request)?;
+
         let response = self
             .client
             .post("/chat/completions")
-            .json(
-                &if let Some(params) = completion_request.additional_params {
-                    json_utils::merge(request, params)
-                } else {
-                    request
-                },
-            )
+            .json(&request)
             .send()
             .await?;
diff --git a/rig-core/src/providers/openrouter/mod.rs b/rig-core/src/providers/openrouter/mod.rs
new file mode 100644
index 0000000..9983815
--- /dev/null
+++ b/rig-core/src/providers/openrouter/mod.rs
@@ -0,0 +1,17 @@
+//! OpenRouter Inference API client and Rig integration
+//!
+//! # Example
+//! ```
+//! use rig::providers::openrouter;
+//!
+//! let client = openrouter::Client::new("YOUR_API_KEY");
+//!
+//! let llama_3_1_8b = client.completion_model(openrouter::LLAMA_3_1_8B);
+//! ```
+
+pub mod client;
+pub mod completion;
+pub mod streaming;
+
+pub use client::*;
+pub use completion::*;
diff --git a/rig-core/src/providers/openrouter/streaming.rs b/rig-core/src/providers/openrouter/streaming.rs
new file mode 100644
index 0000000..1704dcb
--- /dev/null
+++ b/rig-core/src/providers/openrouter/streaming.rs
@@ -0,0 +1,313 @@
+use std::collections::HashMap;
+
+use crate::{
+    json_utils,
+    message::{ToolCall, ToolFunction},
+    streaming::{self},
+};
+use async_stream::stream;
+use futures::StreamExt;
+use reqwest::RequestBuilder;
+use serde_json::{json, Value};
+
+use crate::{
+    completion::{CompletionError, CompletionRequest},
+    streaming::StreamingCompletionModel,
+};
+use serde::{Deserialize, Serialize};
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct StreamingCompletionResponse {
+    pub id: String,
+    pub choices: Vec<StreamingChoice>,
+    pub created: u64,
+    pub model: String,
+    pub object: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub system_fingerprint: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub usage: Option<ResponseUsage>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct StreamingChoice {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub finish_reason: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub native_finish_reason: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub logprobs: Option<Value>,
+    pub index: usize,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub message: Option<MessageResponse>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub delta: Option<DeltaResponse>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error: Option<ErrorResponse>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct MessageResponse {
+    pub role: String,
+    pub content: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub refusal: Option<String>,
+    #[serde(default)]
+    pub tool_calls: Vec<OpenRouterToolCall>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct OpenRouterToolFunction {
+    pub name: Option<String>,
+    pub arguments: Option<String>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct OpenRouterToolCall {
+    pub index: usize,
+    pub id: Option<String>,
+    pub r#type: Option<String>,
+    pub function: OpenRouterToolFunction,
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone, Default)]
+pub struct ResponseUsage {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub total_tokens: u32,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct ErrorResponse {
+    pub code: i32,
+    pub message: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub metadata: Option<HashMap<String, Value>>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct DeltaResponse {
+    pub role: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub content: Option<String>,
+    #[serde(default)]
+    pub tool_calls: Vec<OpenRouterToolCall>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub native_finish_reason: Option<String>,
+}
+
+#[derive(Clone)]
+pub struct FinalCompletionResponse {
+    pub usage: ResponseUsage,
+}
+
+impl StreamingCompletionModel for super::CompletionModel {
+    type StreamingResponse = FinalCompletionResponse;
+
+    async fn stream(
+        &self,
+        completion_request: CompletionRequest,
+    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
+    {
+        let request = self.create_completion_request(completion_request)?;
+
+        let request = json_utils::merge(request, json!({"stream": true}));
+
+        let builder = self.client.post("/chat/completions").json(&request);
+
+        send_streaming_request(builder).await
+    }
+}
+
+pub async fn send_streaming_request(
+    request_builder: RequestBuilder,
+) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError> {
+    let response = request_builder.send().await?;
+
+    if !response.status().is_success() {
+        return Err(CompletionError::ProviderError(format!(
+            "{}: {}",
+            response.status(),
+            response.text().await?
+        )));
+    }
+
+    // Handle OpenAI Compatible SSE chunks
+    let stream = Box::pin(stream! {
+        let mut stream = response.bytes_stream();
+        let mut tool_calls = HashMap::new();
+        let mut partial_line = String::new();
+        let mut final_usage = None;
+
+        while let Some(chunk_result) = stream.next().await {
+            let chunk = match chunk_result {
+                Ok(c) => c,
+                Err(e) => {
+                    yield Err(CompletionError::from(e));
+                    break;
+                }
+            };
+
+            let text = match String::from_utf8(chunk.to_vec()) {
+                Ok(t) => t,
+                Err(e) => {
+                    yield Err(CompletionError::ResponseError(e.to_string()));
+                    break;
+                }
+            };
+
+            for line in text.lines() {
+                let mut line = line.to_string();
+
+                // Skip empty lines and processing messages, as well as [DONE] (might be useful though)
+                if line.trim().is_empty() || line.trim() == ": OPENROUTER PROCESSING" || line.trim() == "data: [DONE]" {
+                    continue;
+                }
+
+                // Handle data: prefix
+                line = line.strip_prefix("data: ").unwrap_or(&line).to_string();
+
+                // If line starts with { but doesn't end with }, it's a partial JSON
+                if line.starts_with('{') && !line.ends_with('}') {
+                    partial_line = line;
+                    continue;
+                }
+
+                // If we have a partial line and this line ends with }, complete it
+                if !partial_line.is_empty() {
+                    if line.ends_with('}') {
+                        partial_line.push_str(&line);
+                        line = partial_line;
+                        partial_line = String::new();
+                    } else {
+                        partial_line.push_str(&line);
+                        continue;
+                    }
+                }
+
+                let data = match serde_json::from_str::<StreamingCompletionResponse>(&line) {
+                    Ok(data) => data,
+                    Err(_) => {
+                        continue;
+                    }
+                };
+
+                let choice = data.choices.first().expect("Should have at least one choice");
+
+                // TODO this has to handle outputs like this:
+                // [{"index": 0, "id": "call_DdmO9pD3xa9XTPNJ32zg2hcA", "function": {"arguments": "", "name": "get_weather"}, "type": "function"}]
+                // [{"index": 0, "id": null, "function": {"arguments": "{\"", "name": null}, "type": null}]
+                // [{"index": 0, "id": null, "function": {"arguments": "location", "name": null}, "type": null}]
+                // [{"index": 0, "id": null, "function": {"arguments": "\":\"", "name": null}, "type": null}]
+                // [{"index": 0, "id": null, "function": {"arguments": "Paris", "name": null}, "type": null}]
+                // [{"index": 0, "id": null, "function": {"arguments": ",", "name": null}, "type": null}]
+                // [{"index": 0, "id": null, "function": {"arguments": " France", "name": null}, "type": null}]
+                // [{"index": 0, "id": null, "function": {"arguments": "\"}", "name": null}, "type": null}]
+                if let Some(delta) = &choice.delta {
+                    if !delta.tool_calls.is_empty() {
+                        for tool_call in &delta.tool_calls {
+                            let index = tool_call.index;
+
+                            // Get or create tool call entry
+                            let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
+                                id: String::new(),
+                                function: ToolFunction {
+                                    name: String::new(),
+                                    arguments: serde_json::Value::Null,
+                                },
+                            });
+
+                            // Update fields if present
+                            if let Some(id) = &tool_call.id {
+                                if !id.is_empty() {
+                                    existing_tool_call.id = id.clone();
+                                }
+                            }
+                            if let Some(name) = &tool_call.function.name {
+                                if !name.is_empty() {
+                                    existing_tool_call.function.name = name.clone();
+                                }
+                            }
+                            if let Some(chunk) = &tool_call.function.arguments {
+                                // Convert current arguments to string if needed
+                                let current_args = match &existing_tool_call.function.arguments {
+                                    serde_json::Value::Null => String::new(),
+                                    serde_json::Value::String(s) => s.clone(),
+                                    v => v.to_string(),
+                                };
+
+                                // Concatenate the new chunk
+                                let combined = format!("{}{}", current_args, chunk);
+
+                                // Try to parse as JSON if it looks complete
+                                if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
+                                    match serde_json::from_str(&combined) {
+                                        Ok(parsed) => existing_tool_call.function.arguments = parsed,
+                                        Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
+                                    }
+                                } else {
+                                    existing_tool_call.function.arguments = serde_json::Value::String(combined);
+                                }
+                            }
+                        }
+                    }
+
+                    if let Some(content) = &delta.content {
+                        if !content.is_empty() {
+                            yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
+                        }
+                    }
+
+                    if let Some(usage) = data.usage {
+                        final_usage = Some(usage);
+                    }
+                }
+
+                // Handle message format
+                if let Some(message) = &choice.message {
+                    if !message.tool_calls.is_empty() {
+                        for tool_call in &message.tool_calls {
+                            let name = tool_call.function.name.clone();
+                            let id = tool_call.id.clone();
+                            let arguments = if let Some(args) = &tool_call.function.arguments {
+                                // Try to parse the string as JSON, fallback to string value
+                                match serde_json::from_str(args) {
+                                    Ok(v) => v,
+                                    Err(_) => serde_json::Value::String(args.to_string()),
+                                }
+                            } else {
+                                serde_json::Value::Null
+                            };
+                            let index = tool_call.index;
+
+                            tool_calls.insert(index, ToolCall {
+                                id: id.unwrap_or_default(),
+                                function: ToolFunction {
+                                    name: name.unwrap_or_default(),
+                                    arguments,
+                                },
+                            });
+                        }
+                    }
+
+                    if !message.content.is_empty() {
+                        yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()))
+                    }
+                }
+            }
+        }
+
+        for (_, tool_call) in tool_calls.into_iter() {
+            yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call.function.name, tool_call.id, tool_call.function.arguments));
+        }
+
+        yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
+            usage: final_usage.unwrap_or_default()
+        }))
+    });
+
+    Ok(streaming::StreamingCompletionResponse::new(stream))
+}
From a78969fd9a5e2d5a49dfac9f0482eb879fdb11fc Mon Sep 17 00:00:00 2001
From: yavens <179155341+yavens@users.github.noreply.github.com>
Date: Tue, 15 Apr 2025 19:01:19 -0400
Subject: [PATCH 16/16] chore: add openrouter example

Co-authored-by: Mochan <0xMochan@users.noreply.github.com>
---
 .../openrouter_streaming_with_tools.rs        | 118 ++++++++++++++++++
 1 file changed, 118 insertions(+)
 create mode 100644 rig-core/examples/openrouter_streaming_with_tools.rs

diff --git a/rig-core/examples/openrouter_streaming_with_tools.rs b/rig-core/examples/openrouter_streaming_with_tools.rs
new file mode 100644
index 0000000..96a2256
--- /dev/null
+++ b/rig-core/examples/openrouter_streaming_with_tools.rs
@@ -0,0 +1,118 @@
+use anyhow::Result;
+use rig::streaming::stream_to_stdout;
+use rig::{completion::ToolDefinition, providers, streaming::StreamingPrompt, tool::Tool};
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+
+#[derive(Deserialize)]
+struct OperationArgs {
+    x: i32,
+    y: i32,
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error("Math error")]
+struct MathError;
+
+#[derive(Deserialize, Serialize)]
+struct Adder;
+impl Tool for Adder {
+    const NAME: &'static str = "add";
+
+    type Error = MathError;
+    type Args = OperationArgs;
+    type Output = i32;
+
+    async fn definition(&self, _prompt: String) -> ToolDefinition {
+        ToolDefinition {
+            name: "add".to_string(),
+            description: "Add x and y together".to_string(),
+            parameters: json!({
+                "type": "object",
+                "properties": {
+                    "x": {
+                        "type": "number",
+                        "description": "The first number to add"
+                    },
+                    "y": {
+                        "type": "number",
+                        "description": "The second number to add"
+                    }
+                },
+                "required": ["x", "y"]
+            }),
+        }
+    }
+
+    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
+        let result = args.x + args.y;
+        Ok(result)
+    }
+}
+
+#[derive(Deserialize, Serialize)]
+struct Subtract;
+impl Tool for Subtract {
+    const NAME: &'static str = "subtract";
+
+    type Error = MathError;
+    type Args = OperationArgs;
+    type Output = i32;
+
+    async fn definition(&self, _prompt: String) -> ToolDefinition {
+        serde_json::from_value(json!({
+            "name": "subtract",
+            "description": "Subtract y from x (i.e.: x - y)",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "x": {
+                        "type": "number",
+                        "description": "The number to subtract from"
+                    },
+                    "y": {
+                        "type": "number",
+                        "description": "The number to subtract"
+                    }
+                },
+                "required": ["x", "y"]
+            }
+        }))
+        .expect("Tool Definition")
+    }
+
+    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
+        let result = args.x - args.y;
+        Ok(result)
+    }
+}
+
+#[tokio::main]
+async fn main() -> Result<(), anyhow::Error> {
+    tracing_subscriber::fmt().init();
+    // Create agent with a single context prompt and two tools
+    let calculator_agent = providers::openrouter::Client::from_env()
+        .agent(providers::openrouter::GEMINI_FLASH_2_0)
+        .preamble(
+            "You are a calculator here to help the user perform arithmetic
+            operations. Use the tools provided to answer the user's question.
+            make your answer long, so we can test the streaming functionality,
+            like 20 words",
+        )
+        .max_tokens(1024)
+        .tool(Adder)
+        .tool(Subtract)
+        .build();
+
+    println!("Calculate 2 - 5");
+    let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
+    stream_to_stdout(calculator_agent, &mut stream).await?;
+
+    if let Some(response) = stream.response {
+        println!("Usage: {:?}", response.usage)
+    };
+
+    println!("Message: {:?}", stream.choice);
+
+    Ok(())
+}
\ No newline at end of file
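
A note for reviewers: the core technique in `send_streaming_request` is accumulating the fragmented tool-call deltas (illustrated by the TODO comment in `streaming.rs`) into complete calls before yielding them. The sketch below reproduces just that accumulation in isolation so it can be eyeballed without the SSE plumbing; the `PendingToolCall` struct and the hard-coded fragments are illustrative stand-ins rather than part of this patch, and only a `serde_json` dependency is assumed.

```rust
use std::collections::HashMap;

// Illustrative stand-in for the (index -> partial tool call) state kept in the stream loop.
#[derive(Debug, Default)]
struct PendingToolCall {
    id: String,
    name: String,
    arguments: String, // raw JSON fragments, concatenated as they arrive
}

fn main() {
    // Deltas shaped like the TODO comment's example: the first chunk carries id/name,
    // later chunks carry only argument fragments. Tuple: (index, id, name, fragment).
    let deltas: &[(usize, Option<&str>, Option<&str>, Option<&str>)] = &[
        (0, Some("call_DdmO9pD3xa9XTPNJ32zg2hcA"), Some("get_weather"), Some("")),
        (0, None, None, Some("{\"location")),
        (0, None, None, Some("\":\"")),
        (0, None, None, Some("Paris, France\"}")),
    ];

    let mut calls: HashMap<usize, PendingToolCall> = HashMap::new();

    for &(index, id, name, fragment) in deltas {
        let entry = calls.entry(index).or_default();
        // Only overwrite id/name when a non-empty value arrives, as the patch does.
        if let Some(id) = id.filter(|s| !s.is_empty()) {
            entry.id = id.to_string();
        }
        if let Some(name) = name.filter(|s| !s.is_empty()) {
            entry.name = name.to_string();
        }
        if let Some(fragment) = fragment {
            entry.arguments.push_str(fragment);
        }
    }

    // Once the stream is drained, each accumulated argument string should be complete JSON.
    for call in calls.values() {
        let arguments: serde_json::Value =
            serde_json::from_str(&call.arguments).expect("fragments should form complete JSON");
        println!("{} ({}): {}", call.name, call.id, arguments);
    }
}
```

Run against the fragment sequence above, this prints `get_weather (call_DdmO9pD3xa9XTPNJ32zg2hcA): {"location":"Paris, France"}` — the reassembled call that the real stream only yields as a `RawStreamingChoice::ToolCall` after the SSE stream ends.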