From 2920eb0a0e91f06034e5678aa8331f546d34d998 Mon Sep 17 00:00:00 2001 From: yavens <179155341+yavens@users.github.noreply.github.com> Date: Wed, 9 Apr 2025 15:33:26 -0400 Subject: [PATCH] feat: update examples --- rig-core/examples/anthropic_streaming.rs | 8 ++ .../anthropic_streaming_with_tools.rs | 7 ++ rig-core/examples/gemini_streaming.rs | 5 + .../examples/gemini_streaming_with_tools.rs | 7 ++ rig-core/examples/ollama_streaming.rs | 5 + .../examples/ollama_streaming_with_tools.rs | 7 ++ rig-core/examples/openai_streaming.rs | 2 + .../examples/openai_streaming_with_tools.rs | 7 ++ rig-core/src/providers/gemini/completion.rs | 2 +- rig-core/src/providers/gemini/streaming.rs | 14 ++- rig-core/src/providers/ollama.rs | 12 +++ rig-core/src/providers/openai/streaming.rs | 96 ++++++++++--------- rig-core/src/streaming.rs | 17 ++-- 13 files changed, 134 insertions(+), 55 deletions(-) diff --git a/rig-core/examples/anthropic_streaming.rs b/rig-core/examples/anthropic_streaming.rs index 349a45d..015189a 100644 --- a/rig-core/examples/anthropic_streaming.rs +++ b/rig-core/examples/anthropic_streaming.rs @@ -19,5 +19,13 @@ async fn main() -> Result<(), anyhow::Error> { stream_to_stdout(agent, &mut stream).await?; + + if let Some(response) = stream.response { + println!("Usage: {:?} tokens", response.usage.output_tokens); + }; + + println!("Message: {:?}", stream.message); + + Ok(()) } diff --git a/rig-core/examples/anthropic_streaming_with_tools.rs b/rig-core/examples/anthropic_streaming_with_tools.rs index ec3ee7c..9dc53b4 100644 --- a/rig-core/examples/anthropic_streaming_with_tools.rs +++ b/rig-core/examples/anthropic_streaming_with_tools.rs @@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> { println!("Calculate 2 - 5"); let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?; stream_to_stdout(calculator_agent, &mut stream).await?; + + if let Some(response) = stream.response { + println!("Usage: {:?} tokens", response.usage.output_tokens); + }; + + println!("Message: {:?}", stream.message); + Ok(()) } diff --git a/rig-core/examples/gemini_streaming.rs b/rig-core/examples/gemini_streaming.rs index 1ff711b..34f57bd 100644 --- a/rig-core/examples/gemini_streaming.rs +++ b/rig-core/examples/gemini_streaming.rs @@ -19,5 +19,10 @@ async fn main() -> Result<(), anyhow::Error> { stream_to_stdout(agent, &mut stream).await?; + if let Some(response) = stream.response { + println!("Usage: {:?} tokens", response.usage_metadata.total_token_count); + }; + + println!("Message: {:?}", stream.message); Ok(()) } diff --git a/rig-core/examples/gemini_streaming_with_tools.rs b/rig-core/examples/gemini_streaming_with_tools.rs index 43f469d..a5fb1c9 100644 --- a/rig-core/examples/gemini_streaming_with_tools.rs +++ b/rig-core/examples/gemini_streaming_with_tools.rs @@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> { println!("Calculate 2 - 5"); let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?; stream_to_stdout(calculator_agent, &mut stream).await?; + + if let Some(response) = stream.response { + println!("Usage: {:?} tokens", response.usage_metadata.total_token_count); + }; + + println!("Message: {:?}", stream.message); + Ok(()) } diff --git a/rig-core/examples/ollama_streaming.rs b/rig-core/examples/ollama_streaming.rs index fe12467..1bc1e33 100644 --- a/rig-core/examples/ollama_streaming.rs +++ b/rig-core/examples/ollama_streaming.rs @@ -17,5 +17,10 @@ async fn main() -> Result<(), anyhow::Error> { stream_to_stdout(agent, &mut stream).await?; + if let Some(response) = stream.response { + println!("Usage: {:?} tokens", response.eval_count); + }; + + println!("Message: {:?}", stream.message); Ok(()) } diff --git a/rig-core/examples/ollama_streaming_with_tools.rs b/rig-core/examples/ollama_streaming_with_tools.rs index 0e59549..4b4427f 100644 --- a/rig-core/examples/ollama_streaming_with_tools.rs +++ b/rig-core/examples/ollama_streaming_with_tools.rs @@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> { println!("Calculate 2 - 5"); let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?; stream_to_stdout(calculator_agent, &mut stream).await?; + + if let Some(response) = stream.response { + println!("Usage: {:?} tokens", response.eval_count); + }; + + println!("Message: {:?}", stream.message); + Ok(()) } diff --git a/rig-core/examples/openai_streaming.rs b/rig-core/examples/openai_streaming.rs index a474702..f668036 100644 --- a/rig-core/examples/openai_streaming.rs +++ b/rig-core/examples/openai_streaming.rs @@ -20,6 +20,8 @@ async fn main() -> Result<(), anyhow::Error> { if let Some(response) = stream.response { println!("Usage: {:?}", response.usage) }; + + println!("Message: {:?}", stream.message); Ok(()) } diff --git a/rig-core/examples/openai_streaming_with_tools.rs b/rig-core/examples/openai_streaming_with_tools.rs index 997bebb..72c4aeb 100644 --- a/rig-core/examples/openai_streaming_with_tools.rs +++ b/rig-core/examples/openai_streaming_with_tools.rs @@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> { println!("Calculate 2 - 5"); let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?; stream_to_stdout(calculator_agent, &mut stream).await?; + + if let Some(response) = stream.response { + println!("Usage: {:?}", response.usage) + }; + + println!("Message: {:?}", stream.message); + Ok(()) } diff --git a/rig-core/src/providers/gemini/completion.rs b/rig-core/src/providers/gemini/completion.rs index ccfdeff..7491fc5 100644 --- a/rig-core/src/providers/gemini/completion.rs +++ b/rig-core/src/providers/gemini/completion.rs @@ -609,7 +609,7 @@ pub mod gemini_api_types { HarmCategoryCivicIntegrity, } - #[derive(Debug, Deserialize, Clone)] + #[derive(Debug, Deserialize, Clone, Default)] #[serde(rename_all = "camelCase")] pub struct UsageMetadata { pub prompt_token_count: i32, diff --git a/rig-core/src/providers/gemini/streaming.rs b/rig-core/src/providers/gemini/streaming.rs index 8e004af..66a6a82 100644 --- a/rig-core/src/providers/gemini/streaming.rs +++ b/rig-core/src/providers/gemini/streaming.rs @@ -9,18 +9,24 @@ use crate::{ streaming::{self, StreamingCompletionModel}, }; +#[derive(Debug, Deserialize, Default, Clone)] +#[serde(rename_all = "camelCase")] +pub struct PartialUsage { + pub total_token_count: i32, +} + #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct StreamGenerateContentResponse { /// Candidate responses from the model. pub candidates: Vec, pub model_version: Option, - pub usage_metadata: UsageMetadata, + pub usage_metadata: Option, } #[derive(Clone)] pub struct StreamingCompletionResponse { - pub usage_metadata: UsageMetadata, + pub usage_metadata: PartialUsage, } impl StreamingCompletionModel for CompletionModel { @@ -90,7 +96,9 @@ impl StreamingCompletionModel for CompletionModel { if choice.finish_reason.is_some() { yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse { - usage_metadata: data.usage_metadata, + usage_metadata: PartialUsage { + total_token_count: data.usage_metadata.unwrap().total_token_count, + } })) } } diff --git a/rig-core/src/providers/ollama.rs b/rig-core/src/providers/ollama.rs index 4c93475..a754fe2 100644 --- a/rig-core/src/providers/ollama.rs +++ b/rig-core/src/providers/ollama.rs @@ -486,6 +486,18 @@ impl StreamingCompletionModel for CompletionModel { continue; } } + + if response.done { + yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse { + total_duration: response.total_duration, + load_duration: response.load_duration, + prompt_eval_count: response.prompt_eval_count, + prompt_eval_duration: response.prompt_eval_duration, + eval_count: response.eval_count, + eval_duration: response.eval_duration, + done_reason: response.done_reason, + })); + } } } }); diff --git a/rig-core/src/providers/openai/streaming.rs b/rig-core/src/providers/openai/streaming.rs index 015bc8a..c42f1a4 100644 --- a/rig-core/src/providers/openai/streaming.rs +++ b/rig-core/src/providers/openai/streaming.rs @@ -46,12 +46,11 @@ struct StreamingChoice { struct StreamingCompletionChunk { choices: Vec, usage: Option, - finish_reason: Option, } #[derive(Clone)] pub struct StreamingCompletionResponse { - pub usage: Option, + pub usage: Usage, } impl StreamingCompletionModel for CompletionModel { @@ -62,7 +61,10 @@ impl StreamingCompletionModel for CompletionModel { ) -> Result, CompletionError> { let mut request = self.create_completion_request(completion_request)?; - request = merge(request, json!({"stream": true})); + request = merge( + request, + json!({"stream": true, "stream_options": {"include_usage": true}}), + ); let builder = self.client.post("/chat/completions").json(&request); send_compatible_streaming_request(builder).await @@ -86,6 +88,11 @@ pub async fn send_compatible_streaming_request( let inner = Box::pin(stream! { let mut stream = response.bytes_stream(); + let mut final_usage = Usage { + prompt_tokens: 0, + total_tokens: 0 + }; + let mut partial_data = None; let mut calls: HashMap = HashMap::new(); @@ -110,8 +117,6 @@ pub async fn send_compatible_streaming_request( for line in text.lines() { let mut line = line.to_string(); - - // If there was a remaining part, concat with current line if partial_data.is_some() { line = format!("{}{}", partial_data.unwrap(), line); @@ -137,56 +142,53 @@ pub async fn send_compatible_streaming_request( continue; }; - let choice = data.choices.first().expect("Should have at least one choice"); - let delta = &choice.delta; + if let Some(choice) = data.choices.first() { - if !delta.tool_calls.is_empty() { - for tool_call in &delta.tool_calls { - let function = tool_call.function.clone(); + let delta = &choice.delta; - // Start of tool call - // name: Some(String) - // arguments: None - if function.name.is_some() && function.arguments.is_empty() { - calls.insert(tool_call.index, (function.name.clone().unwrap(), "".to_string())); - } - // Part of tool call - // name: None - // arguments: Some(String) - else if function.name.is_none() && !function.arguments.is_empty() { - let Some((name, arguments)) = calls.get(&tool_call.index) else { - continue; - }; + if !delta.tool_calls.is_empty() { + for tool_call in &delta.tool_calls { + let function = tool_call.function.clone(); - let new_arguments = &tool_call.function.arguments; - let arguments = format!("{}{}", arguments, new_arguments); + // Start of tool call + // name: Some(String) + // arguments: None + if function.name.is_some() && function.arguments.is_empty() { + calls.insert(tool_call.index, (function.name.clone().unwrap(), "".to_string())); + } + // Part of tool call + // name: None + // arguments: Some(String) + else if function.name.is_none() && !function.arguments.is_empty() { + let Some((name, arguments)) = calls.get(&tool_call.index) else { + continue; + }; - calls.insert(tool_call.index, (name.clone(), arguments)); - } - // Entire tool call - else { - let name = function.name.unwrap(); - let arguments = function.arguments; - let Ok(arguments) = serde_json::from_str(&arguments) else { - continue; - }; + let new_arguments = &tool_call.function.arguments; + let arguments = format!("{}{}", arguments, new_arguments); - yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments)) + calls.insert(tool_call.index, (name.clone(), arguments)); + } + // Entire tool call + else { + let name = function.name.unwrap(); + let arguments = function.arguments; + let Ok(arguments) = serde_json::from_str(&arguments) else { + continue; + }; + + yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments)) + } } } + + } - if let Some(content) = &choice.delta.content { - yield Ok(streaming::RawStreamingChoice::Message(content.clone())) + if let Some(usage) = data.usage { + final_usage = usage.clone(); } - - if data.finish_reason.is_some() { - yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse { - usage: data.usage - })) - } - } } @@ -195,8 +197,12 @@ pub async fn send_compatible_streaming_request( continue; }; - yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments)) + yield Ok(RawStreamingChoice::ToolCall(name, "".to_string(), arguments)) } + + yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse { + usage: final_usage.clone() + })) }); Ok(streaming::StreamingCompletionResponse::new(inner)) diff --git a/rig-core/src/streaming.rs b/rig-core/src/streaming.rs index 4d61116..5edc2c3 100644 --- a/rig-core/src/streaming.rs +++ b/rig-core/src/streaming.rs @@ -91,21 +91,25 @@ impl Stream for StreamingCompletionResponse { match stream.inner.as_mut().poll_next(cx) { Poll::Pending => Poll::Pending, - Poll::Ready(None) => { - let content = vec![AssistantContent::text(stream.text.clone())]; + + let mut content = vec![]; stream.tool_calls.iter().for_each(|(n, d, a)| { - AssistantContent::tool_call(n, d, a.clone()); + content.push(AssistantContent::tool_call(n, d, a.clone())); }); + if content.len() == 0 || stream.text.len() > 0 { + content.insert(0, AssistantContent::text(stream.text.clone())); + } + stream.message = Message::Assistant { content: OneOrMany::many(content) .expect("There should be at least one assistant message"), }; - + Poll::Ready(None) - } + }, Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), Poll::Ready(Some(Ok(choice))) => match choice { RawStreamingChoice::Message(text) => { @@ -120,7 +124,8 @@ impl Stream for StreamingCompletionResponse { } RawStreamingChoice::FinalResponse(response) => { stream.response = Some(response); - Poll::Pending + + stream.poll_next_unpin(cx) } }, }