diff --git a/rig-core/examples/anthropic_streaming.rs b/rig-core/examples/anthropic_streaming.rs
index 349a45d..fa0c66f 100644
--- a/rig-core/examples/anthropic_streaming.rs
+++ b/rig-core/examples/anthropic_streaming.rs
@@ -19,5 +19,11 @@ async fn main() -> Result<(), anyhow::Error> {
 
     stream_to_stdout(agent, &mut stream).await?;
 
+    if let Some(response) = stream.response {
+        println!("Usage: {:?} tokens", response.usage.output_tokens);
+    };
+
+    println!("Message: {:?}", stream.choice);
+
     Ok(())
 }
diff --git a/rig-core/examples/anthropic_streaming_with_tools.rs b/rig-core/examples/anthropic_streaming_with_tools.rs
index ec3ee7c..5914d10 100644
--- a/rig-core/examples/anthropic_streaming_with_tools.rs
+++ b/rig-core/examples/anthropic_streaming_with_tools.rs
@@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> {
     println!("Calculate 2 - 5");
     let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
     stream_to_stdout(calculator_agent, &mut stream).await?;
+
+    if let Some(response) = stream.response {
+        println!("Usage: {:?} tokens", response.usage.output_tokens);
+    };
+
+    println!("Message: {:?}", stream.choice);
+
     Ok(())
 }
diff --git a/rig-core/examples/cohere_streaming.rs b/rig-core/examples/cohere_streaming.rs
new file mode 100644
index 0000000..d6fb6eb
--- /dev/null
+++ b/rig-core/examples/cohere_streaming.rs
@@ -0,0 +1,27 @@
+use rig::providers::cohere;
+use rig::streaming::{stream_to_stdout, StreamingPrompt};
+
+#[tokio::main]
+async fn main() -> Result<(), anyhow::Error> {
+    // Create streaming agent with a single context prompt
+    let agent = cohere::Client::from_env()
+        .agent(cohere::COMMAND)
+        .preamble("Be precise and concise.")
+        .temperature(0.5)
+        .build();
+
+    // Stream the response and print chunks as they arrive
+    let mut stream = agent
+        .stream_prompt("When and where and what type is the next solar eclipse?")
+        .await?;
+
+    stream_to_stdout(agent, &mut stream).await?;
+
+    if let Some(response) = stream.response {
+        println!("Usage: {:?} tokens", response.usage);
+    };
+
+    println!("Message: {:?}", stream.choice);
+
+    Ok(())
+}
diff --git a/rig-core/examples/cohere_streaming_with_tools.rs b/rig-core/examples/cohere_streaming_with_tools.rs
new file mode 100644
index 0000000..53012d1
--- /dev/null
+++ b/rig-core/examples/cohere_streaming_with_tools.rs
@@ -0,0 +1,118 @@
+use anyhow::Result;
+use rig::streaming::stream_to_stdout;
+use rig::{completion::ToolDefinition, providers, streaming::StreamingPrompt, tool::Tool};
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+
+#[derive(Deserialize)]
+struct OperationArgs {
+    x: i32,
+    y: i32,
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error("Math error")]
+struct MathError;
+
+#[derive(Deserialize, Serialize)]
+struct Adder;
+impl Tool for Adder {
+    const NAME: &'static str = "add";
+
+    type Error = MathError;
+    type Args = OperationArgs;
+    type Output = i32;
+
+    async fn definition(&self, _prompt: String) -> ToolDefinition {
+        ToolDefinition {
+            name: "add".to_string(),
+            description: "Add x and y together".to_string(),
+            parameters: json!({
+                "type": "object",
+                "properties": {
+                    "x": {
+                        "type": "number",
+                        "description": "The first number to add"
+                    },
+                    "y": {
+                        "type": "number",
+                        "description": "The second number to add"
+                    }
+                },
+                "required": ["x", "y"]
+            }),
+        }
+    }
+
+    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
+        let result = args.x + args.y;
+        Ok(result)
+    }
+}
+
+#[derive(Deserialize, Serialize)]
+struct Subtract;
+impl Tool for Subtract {
+    const NAME: &'static str = "subtract";
"subtract"; + + type Error = MathError; + type Args = OperationArgs; + type Output = i32; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + serde_json::from_value(json!({ + "name": "subtract", + "description": "Subtract y from x (i.e.: x - y)", + "parameters": { + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "The number to subtract from" + }, + "y": { + "type": "number", + "description": "The number to subtract" + } + }, + "required": ["x", "y"] + } + })) + .expect("Tool Definition") + } + + async fn call(&self, args: Self::Args) -> Result { + let result = args.x - args.y; + Ok(result) + } +} + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + tracing_subscriber::fmt().init(); + // Create agent with a single context prompt and two tools + let calculator_agent = providers::cohere::Client::from_env() + .agent(providers::cohere::COMMAND_R) + .preamble( + "You are a calculator here to help the user perform arithmetic + operations. Use the tools provided to answer the user's question. + make your answer long, so we can test the streaming functionality, + like 20 words", + ) + .max_tokens(1024) + .tool(Adder) + .tool(Subtract) + .build(); + + println!("Calculate 2 - 5"); + let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?; + stream_to_stdout(calculator_agent, &mut stream).await?; + + if let Some(response) = stream.response { + println!("Usage: {:?} tokens", response.usage); + }; + + println!("Message: {:?}", stream.choice); + + Ok(()) +} diff --git a/rig-core/examples/gemini_streaming.rs b/rig-core/examples/gemini_streaming.rs index 1ff711b..6fa34ae 100644 --- a/rig-core/examples/gemini_streaming.rs +++ b/rig-core/examples/gemini_streaming.rs @@ -19,5 +19,13 @@ async fn main() -> Result<(), anyhow::Error> { stream_to_stdout(agent, &mut stream).await?; + if let Some(response) = stream.response { + println!( + "Usage: {:?} tokens", + response.usage_metadata.total_token_count + ); + }; + + println!("Message: {:?}", stream.choice); Ok(()) } diff --git a/rig-core/examples/gemini_streaming_with_tools.rs b/rig-core/examples/gemini_streaming_with_tools.rs index 43f469d..ffdd135 100644 --- a/rig-core/examples/gemini_streaming_with_tools.rs +++ b/rig-core/examples/gemini_streaming_with_tools.rs @@ -107,5 +107,15 @@ async fn main() -> Result<(), anyhow::Error> { println!("Calculate 2 - 5"); let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?; stream_to_stdout(calculator_agent, &mut stream).await?; + + if let Some(response) = stream.response { + println!( + "Usage: {:?} tokens", + response.usage_metadata.total_token_count + ); + }; + + println!("Message: {:?}", stream.choice); + Ok(()) } diff --git a/rig-core/examples/ollama_streaming.rs b/rig-core/examples/ollama_streaming.rs index fe12467..9c745f5 100644 --- a/rig-core/examples/ollama_streaming.rs +++ b/rig-core/examples/ollama_streaming.rs @@ -17,5 +17,10 @@ async fn main() -> Result<(), anyhow::Error> { stream_to_stdout(agent, &mut stream).await?; + if let Some(response) = stream.response { + println!("Usage: {:?} tokens", response.eval_count); + }; + + println!("Message: {:?}", stream.choice); Ok(()) } diff --git a/rig-core/examples/ollama_streaming_with_tools.rs b/rig-core/examples/ollama_streaming_with_tools.rs index 0e59549..d61f16e 100644 --- a/rig-core/examples/ollama_streaming_with_tools.rs +++ b/rig-core/examples/ollama_streaming_with_tools.rs @@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> { 
println!("Calculate 2 - 5"); let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?; stream_to_stdout(calculator_agent, &mut stream).await?; + + if let Some(response) = stream.response { + println!("Usage: {:?} tokens", response.eval_count); + }; + + println!("Message: {:?}", stream.choice); + Ok(()) } diff --git a/rig-core/examples/openai_streaming.rs b/rig-core/examples/openai_streaming.rs index d4aadf0..87772da 100644 --- a/rig-core/examples/openai_streaming.rs +++ b/rig-core/examples/openai_streaming.rs @@ -17,5 +17,11 @@ async fn main() -> Result<(), anyhow::Error> { stream_to_stdout(agent, &mut stream).await?; + if let Some(response) = stream.response { + println!("Usage: {:?}", response.usage) + }; + + println!("Message: {:?}", stream.choice); + Ok(()) } diff --git a/rig-core/examples/openai_streaming_with_tools.rs b/rig-core/examples/openai_streaming_with_tools.rs index 997bebb..6d57855 100644 --- a/rig-core/examples/openai_streaming_with_tools.rs +++ b/rig-core/examples/openai_streaming_with_tools.rs @@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> { println!("Calculate 2 - 5"); let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?; stream_to_stdout(calculator_agent, &mut stream).await?; + + if let Some(response) = stream.response { + println!("Usage: {:?}", response.usage) + }; + + println!("Message: {:?}", stream.choice); + Ok(()) } diff --git a/rig-core/examples/openrouter_streaming_with_tools.rs b/rig-core/examples/openrouter_streaming_with_tools.rs new file mode 100644 index 0000000..96a2256 --- /dev/null +++ b/rig-core/examples/openrouter_streaming_with_tools.rs @@ -0,0 +1,118 @@ +use anyhow::Result; +use rig::streaming::stream_to_stdout; +use rig::{completion::ToolDefinition, providers, streaming::StreamingPrompt, tool::Tool}; +use serde::{Deserialize, Serialize}; +use serde_json::json; + +#[derive(Deserialize)] +struct OperationArgs { + x: i32, + y: i32, +} + +#[derive(Debug, thiserror::Error)] +#[error("Math error")] +struct MathError; + +#[derive(Deserialize, Serialize)] +struct Adder; +impl Tool for Adder { + const NAME: &'static str = "add"; + + type Error = MathError; + type Args = OperationArgs; + type Output = i32; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + ToolDefinition { + name: "add".to_string(), + description: "Add x and y together".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "The first number to add" + }, + "y": { + "type": "number", + "description": "The second number to add" + } + }, + "required": ["x", "y"] + }), + } + } + + async fn call(&self, args: Self::Args) -> Result { + let result = args.x + args.y; + Ok(result) + } +} + +#[derive(Deserialize, Serialize)] +struct Subtract; +impl Tool for Subtract { + const NAME: &'static str = "subtract"; + + type Error = MathError; + type Args = OperationArgs; + type Output = i32; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + serde_json::from_value(json!({ + "name": "subtract", + "description": "Subtract y from x (i.e.: x - y)", + "parameters": { + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "The number to subtract from" + }, + "y": { + "type": "number", + "description": "The number to subtract" + } + }, + "required": ["x", "y"] + } + })) + .expect("Tool Definition") + } + + async fn call(&self, args: Self::Args) -> Result { + let result = args.x - args.y; + Ok(result) + } +} + +#[tokio::main] +async 
+async fn main() -> Result<(), anyhow::Error> {
+    tracing_subscriber::fmt().init();
+    // Create agent with a single context prompt and two tools
+    let calculator_agent = providers::openrouter::Client::from_env()
+        .agent(providers::openrouter::GEMINI_FLASH_2_0)
+        .preamble(
+            "You are a calculator here to help the user perform arithmetic
+            operations. Use the tools provided to answer the user's question.
+            make your answer long, so we can test the streaming functionality,
+            like 20 words",
+        )
+        .max_tokens(1024)
+        .tool(Adder)
+        .tool(Subtract)
+        .build();
+
+    println!("Calculate 2 - 5");
+    let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
+    stream_to_stdout(calculator_agent, &mut stream).await?;
+
+    if let Some(response) = stream.response {
+        println!("Usage: {:?}", response.usage)
+    };
+
+    println!("Message: {:?}", stream.choice);
+
+    Ok(())
+}
\ No newline at end of file
diff --git a/rig-core/src/agent.rs b/rig-core/src/agent.rs
index 781e2a8..bf5ab0d 100644
--- a/rig-core/src/agent.rs
+++ b/rig-core/src/agent.rs
@@ -110,23 +110,20 @@ use std::collections::HashMap;
 
 use futures::{stream, StreamExt, TryStreamExt};
 
+use crate::streaming::StreamingCompletionResponse;
+#[cfg(feature = "mcp")]
+use crate::tool::McpTool;
 use crate::{
     completion::{
         Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
         Message, Prompt, PromptError,
     },
     message::AssistantContent,
-    streaming::{
-        StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
-        StreamingResult,
-    },
+    streaming::{StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt},
     tool::{Tool, ToolSet},
     vector_store::{VectorStoreError, VectorStoreIndexDyn},
 };
 
-#[cfg(feature = "mcp")]
-use crate::tool::McpTool;
-
 /// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
 /// (i.e.: system prompt) and a static set of context documents and tools.
 /// All context documents and tools are always provided to the agent when prompted.
@@ -500,18 +497,21 @@ impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
     }
 }
 
-impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
-    async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
+impl<M: StreamingCompletionModel> StreamingPrompt<M::StreamingResponse> for Agent<M> {
+    async fn stream_prompt(
+        &self,
+        prompt: &str,
+    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
         self.stream_chat(prompt, vec![]).await
     }
 }
 
-impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
+impl<M: StreamingCompletionModel> StreamingChat<M::StreamingResponse> for Agent<M> {
     async fn stream_chat(
         &self,
         prompt: &str,
         chat_history: Vec<Message>,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
         self.stream_completion(prompt, chat_history)
             .await?
             .stream()
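Everything above and below hangs off one change to StreamingCompletionModel: it gains an associated StreamingResponse type, and stream() now returns a wrapper that both yields chunks and captures the provider's final payload. Reconstructed from the signatures in this patch (the exact bounds are an assumption, not copied from rig source), the trait now looks roughly like:

    pub trait StreamingCompletionModel: CompletionModel {
        /// Provider-specific final payload (usage counters etc.), exposed as
        /// `response` on the wrapper once the stream finishes.
        type StreamingResponse: Clone + Unpin;

        async fn stream(
            &self,
            request: CompletionRequest,
        ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>;
    }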
diff --git a/rig-core/src/completion/request.rs b/rig-core/src/completion/request.rs
index 9a31fae..07e8f64 100644
--- a/rig-core/src/completion/request.rs
+++ b/rig-core/src/completion/request.rs
@@ -67,7 +67,7 @@ use std::collections::HashMap;
 use serde::{Deserialize, Serialize};
 use thiserror::Error;
 
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::OneOrMany;
 use crate::{
     json_utils,
@@ -467,7 +467,9 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
 
 impl<M: StreamingCompletionModel> CompletionRequestBuilder<M> {
     /// Stream the completion request
-    pub async fn stream(self) -> Result<StreamingResult, CompletionError> {
+    pub async fn stream(
+        self,
+    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
         let model = self.model.clone();
         model.stream(self.build()).await
     }
diff --git a/rig-core/src/providers/anthropic/streaming.rs b/rig-core/src/providers/anthropic/streaming.rs
index b351515..c343b44 100644
--- a/rig-core/src/providers/anthropic/streaming.rs
+++ b/rig-core/src/providers/anthropic/streaming.rs
@@ -8,7 +8,8 @@ use super::decoders::sse::from_response as sse_from_response;
 use crate::completion::{CompletionError, CompletionRequest};
 use crate::json_utils::merge_inplace;
 use crate::message::MessageError;
-use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
+use crate::streaming;
+use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult};
 
 #[derive(Debug, Deserialize)]
 #[serde(tag = "type", rename_all = "snake_case")]
@@ -61,7 +62,7 @@ pub struct MessageDelta {
     pub stop_sequence: Option<String>,
 }
 
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Deserialize, Clone)]
 pub struct PartialUsage {
     pub output_tokens: usize,
     #[serde(default)]
@@ -75,11 +76,18 @@ struct ToolCallState {
     input_json: String,
 }
 
+#[derive(Clone)]
+pub struct StreamingCompletionResponse {
+    pub usage: PartialUsage,
+}
+
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
+    {
         let max_tokens = if let Some(tokens) = completion_request.max_tokens {
             tokens
         } else if let Some(tokens) = self.default_max_tokens {
@@ -155,9 +163,10 @@ impl StreamingCompletionModel for CompletionModel {
         // Use our SSE decoder to directly handle Server-Sent Events format
         let sse_stream = sse_from_response(response);
 
-        Ok(Box::pin(stream! {
+        let stream: StreamingResult<Self::StreamingResponse> = Box::pin(stream! {
             let mut current_tool_call: Option<ToolCallState> = None;
             let mut sse_stream = Box::pin(sse_stream);
+            let mut input_tokens = 0;
 
             while let Some(sse_result) = sse_stream.next().await {
                 match sse_result {
@@ -165,6 +174,24 @@ impl StreamingCompletionModel for CompletionModel {
                         // Parse the SSE data as a StreamingEvent
                         match serde_json::from_str::<StreamingEvent>(&sse.data) {
                             Ok(event) => {
+                                match &event {
+                                    StreamingEvent::MessageStart { message } => {
+                                        input_tokens = message.usage.input_tokens;
+                                    },
+                                    StreamingEvent::MessageDelta { delta, usage } => {
+                                        if delta.stop_reason.is_some() {
+
+                                            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
+                                                usage: PartialUsage {
+                                                    output_tokens: usage.output_tokens,
+                                                    input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
+                                                }
+                                            }))
+                                        }
+                                    }
+                                    _ => {}
+                                }
+
                                 if let Some(result) = handle_event(&event, &mut current_tool_call) {
                                     yield result;
                                 }
@@ -184,19 +211,21 @@ impl StreamingCompletionModel for CompletionModel {
                 }
             }
         }
-        }))
+        });
+
+        Ok(streaming::StreamingCompletionResponse::new(stream))
     }
 }
 
 fn handle_event(
     event: &StreamingEvent,
     current_tool_call: &mut Option<ToolCallState>,
-) -> Option<Result<StreamingChoice, CompletionError>> {
+) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
     match event {
         StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
             ContentDelta::TextDelta { text } => {
                 if current_tool_call.is_none() {
-                    return Some(Ok(StreamingChoice::Message(text.clone())));
+                    return Some(Ok(RawStreamingChoice::Message(text.clone())));
                 }
                 None
             }
@@ -227,7 +256,7 @@ fn handle_event(
                     &tool_call.input_json
                 };
                 match serde_json::from_str(json_str) {
-                    Ok(json_value) => Some(Ok(StreamingChoice::ToolCall(
+                    Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall(
                         tool_call.name,
                         tool_call.id,
                         json_value,
diff --git a/rig-core/src/providers/azure.rs b/rig-core/src/providers/azure.rs
index c2e0ec9..663464f 100644
--- a/rig-core/src/providers/azure.rs
+++ b/rig-core/src/providers/azure.rs
@@ -12,7 +12,7 @@ use super::openai::{send_compatible_streaming_request, TranscriptionResponse};
 
 use crate::json_utils::merge;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -570,10 +570,17 @@ impl completion::CompletionModel for CompletionModel {
 // Azure OpenAI Streaming API
 // -----------------------------------------------------
 impl StreamingCompletionModel for CompletionModel {
-    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
+    type StreamingResponse = openai::StreamingCompletionResponse;
+    async fn stream(
+        &self,
+        request: CompletionRequest,
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(request)?;
 
-        request = merge(request, json!({"stream": true}));
+        request = merge(
+            request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );
 
         let builder = self
             .client
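The stream_options change above (repeated for OpenAI, DeepSeek, Galadriel, Groq, Hugging Face, Hyperbolic and Moonshot below) is what makes usage available at all: with include_usage set, an OpenAI-compatible endpoint emits one extra terminal chunk whose choices array is empty and whose usage field is populated, which is exactly the chunk send_compatible_streaming_request turns into the final response. Illustrative terminal chunk (field values invented):

    // What the last SSE chunk looks like with stream_options.include_usage:
    let terminal_chunk = serde_json::json!({
        "id": "chatcmpl-123",
        "object": "chat.completion.chunk",
        "choices": [],
        "usage": { "prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46 }
    });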
diff --git a/rig-core/src/providers/cohere/completion.rs b/rig-core/src/providers/cohere/completion.rs
index 9621bf7..ca41692 100644
--- a/rig-core/src/providers/cohere/completion.rs
+++ b/rig-core/src/providers/cohere/completion.rs
@@ -6,8 +6,9 @@ use crate::{
 };
 
 use super::client::Client;
+use crate::completion::CompletionRequest;
 use serde::{Deserialize, Serialize};
-use serde_json::json;
+use serde_json::{json, Value};
 
 #[derive(Debug, Deserialize)]
 pub struct CompletionResponse {
@@ -419,7 +420,7 @@ impl TryFrom<CompletionResponse> for message::Message {
 
 #[derive(Clone)]
 pub struct CompletionModel {
-    client: Client,
+    pub(crate) client: Client,
     pub model: String,
 }
 
@@ -430,16 +431,11 @@ impl CompletionModel {
             model: model.to_string(),
         }
     }
-}
 
-impl completion::CompletionModel for CompletionModel {
-    type Response = CompletionResponse;
-
-    #[cfg_attr(feature = "worker", worker::send)]
-    async fn completion(
+    pub(crate) fn create_completion_request(
         &self,
-        completion_request: completion::CompletionRequest,
-    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
+        completion_request: CompletionRequest,
+    ) -> Result<Value, CompletionError> {
         let prompt = completion_request.prompt_with_context();
 
         let mut messages: Vec<Message> =
@@ -468,23 +464,29 @@ impl completion::CompletionModel for CompletionModel {
             "tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
         });
 
+        if let Some(ref params) = completion_request.additional_params {
+            Ok(json_utils::merge(request.clone(), params.clone()))
+        } else {
+            Ok(request)
+        }
+    }
+}
+
+impl completion::CompletionModel for CompletionModel {
+    type Response = CompletionResponse;
+
+    #[cfg_attr(feature = "worker", worker::send)]
+    async fn completion(
+        &self,
+        completion_request: completion::CompletionRequest,
+    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
+        let request = self.create_completion_request(completion_request)?;
+
         tracing::debug!(
             "Cohere request: {}",
             serde_json::to_string_pretty(&request)?
         );
 
-        let response = self
-            .client
-            .post("/v2/chat")
-            .json(
-                &if let Some(ref params) = completion_request.additional_params {
-                    json_utils::merge(request.clone(), params.clone())
-                } else {
-                    request.clone()
-                },
-            )
-            .send()
-            .await?;
+        let response = self.client.post("/v2/chat").json(&request).send().await?;
 
         if response.status().is_success() {
             let text_response = response.text().await?;
diff --git a/rig-core/src/providers/cohere/mod.rs b/rig-core/src/providers/cohere/mod.rs
index 4c290c0..b88dc32 100644
--- a/rig-core/src/providers/cohere/mod.rs
+++ b/rig-core/src/providers/cohere/mod.rs
@@ -12,6 +12,7 @@
 pub mod client;
 pub mod completion;
 pub mod embeddings;
+pub mod streaming;
 
 pub use client::Client;
 pub use client::{ApiErrorResponse, ApiResponse};
@@ -23,7 +24,7 @@ pub use embeddings::EmbeddingModel;
 // ================================================================
 
 /// `command-r-plus` completion model
-pub const COMMAND_R_PLUS: &str = "comman-r-plus";
+pub const COMMAND_R_PLUS: &str = "command-r-plus";
 /// `command-r` completion model
 pub const COMMAND_R: &str = "command-r";
 /// `command` completion model
diff --git a/rig-core/src/providers/cohere/streaming.rs b/rig-core/src/providers/cohere/streaming.rs
new file mode 100644
index 0000000..6b9db9f
--- /dev/null
+++ b/rig-core/src/providers/cohere/streaming.rs
@@ -0,0 +1,188 @@
+use crate::completion::{CompletionError, CompletionRequest};
+use crate::providers::cohere::completion::Usage;
+use crate::providers::cohere::CompletionModel;
+use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
+use crate::{json_utils, streaming};
+use async_stream::stream;
+use futures::StreamExt;
+use serde::Deserialize;
+use serde_json::json;
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "kebab-case", tag = "type")]
+enum StreamingEvent {
+    MessageStart,
+    ContentStart,
+    ContentDelta { delta: Option<Delta> },
+    ContentEnd,
+    ToolPlan,
+    ToolCallStart { delta: Option<Delta> },
+    ToolCallDelta { delta: Option<Delta> },
+    ToolCallEnd,
+    MessageEnd { delta: Option<MessageEndDelta> },
+}
+
+#[derive(Debug, Deserialize)]
+struct MessageContentDelta {
+    text: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+struct MessageToolFunctionDelta {
+    name: Option<String>,
+    arguments: Option<String>,
+}
Option, +} + +#[derive(Debug, Deserialize)] +struct MessageToolCallDelta { + id: Option, + function: Option, +} + +#[derive(Debug, Deserialize)] +struct MessageDelta { + content: Option, + tool_calls: Option, +} + +#[derive(Debug, Deserialize)] +struct Delta { + message: Option, +} + +#[derive(Debug, Deserialize)] +struct MessageEndDelta { + usage: Option, +} + +#[derive(Clone)] +pub struct StreamingCompletionResponse { + pub usage: Option, +} + +impl StreamingCompletionModel for CompletionModel { + type StreamingResponse = StreamingCompletionResponse; + + async fn stream( + &self, + request: CompletionRequest, + ) -> Result, CompletionError> + { + let request = self.create_completion_request(request)?; + let request = json_utils::merge(request, json!({"stream": true})); + + tracing::debug!( + "Cohere request: {}", + serde_json::to_string_pretty(&request)? + ); + + let response = self.client.post("/v2/chat").json(&request).send().await?; + + if !response.status().is_success() { + return Err(CompletionError::ProviderError(format!( + "{}: {}", + response.status(), + response.text().await? + ))); + } + + let stream = Box::pin(stream! { + let mut stream = response.bytes_stream(); + let mut current_tool_call: Option<(String, String, String)> = None; + + while let Some(chunk_result) = stream.next().await { + let chunk = match chunk_result { + Ok(c) => c, + Err(e) => { + yield Err(CompletionError::from(e)); + break; + } + }; + + let text = match String::from_utf8(chunk.to_vec()) { + Ok(t) => t, + Err(e) => { + yield Err(CompletionError::ResponseError(e.to_string())); + break; + } + }; + + for line in text.lines() { + + let Some(line) = line.strip_prefix("data: ") else { + continue; + }; + + let event = { + let result = serde_json::from_str::(line); + + let Ok(event) = result else { + continue; + }; + + event + }; + + match event { + StreamingEvent::ContentDelta { delta: Some(delta) } => { + let Some(message) = &delta.message else { continue; }; + let Some(content) = &message.content else { continue; }; + let Some(text) = &content.text else { continue; }; + + yield Ok(RawStreamingChoice::Message(text.clone())); + }, + StreamingEvent::MessageEnd {delta: Some(delta)} => { + yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse { + usage: delta.usage.clone() + })); + }, + StreamingEvent::ToolCallStart { delta: Some(delta)} => { + // Skip the delta if there's any missing information, + // though this *should* all be present + let Some(message) = &delta.message else { continue; }; + let Some(tool_calls) = &message.tool_calls else { continue; }; + let Some(id) = tool_calls.id.clone() else { continue; }; + let Some(function) = &tool_calls.function else { continue; }; + let Some(name) = function.name.clone() else { continue; }; + let Some(arguments) = function.arguments.clone() else { continue; }; + + current_tool_call = Some((id, name, arguments)); + }, + StreamingEvent::ToolCallDelta { delta: Some(delta)} => { + // Skip the delta if there's any missing information, + // though this *should* all be present + let Some(message) = &delta.message else { continue; }; + let Some(tool_calls) = &message.tool_calls else { continue; }; + let Some(function) = &tool_calls.function else { continue; }; + let Some(arguments) = function.arguments.clone() else { continue; }; + + if let Some(tc) = current_tool_call.clone() { + current_tool_call = Some(( + tc.0, + tc.1, + format!("{}{}", tc.2, arguments) + )); + }; + }, + StreamingEvent::ToolCallEnd => { + let Some(tc) = current_tool_call.clone() else { 
diff --git a/rig-core/src/providers/deepseek.rs b/rig-core/src/providers/deepseek.rs
index 86d9a0f..b6302a0 100644
--- a/rig-core/src/providers/deepseek.rs
+++ b/rig-core/src/providers/deepseek.rs
@@ -10,8 +10,9 @@
 //! ```
 use crate::json_utils::merge;
+use crate::providers::openai;
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     completion::{self, CompletionError, CompletionModel, CompletionRequest},
     extractor::ExtractorBuilder,
@@ -463,13 +464,17 @@ impl CompletionModel for DeepSeekCompletionModel {
 }
 
 impl StreamingCompletionModel for DeepSeekCompletionModel {
+    type StreamingResponse = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;
 
-        request = merge(request, json!({"stream": true}));
+        request = merge(
+            request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );
 
         let builder = self.client.post("/v1/chat/completions").json(&request);
         send_compatible_streaming_request(builder).await
diff --git a/rig-core/src/providers/galadriel.rs b/rig-core/src/providers/galadriel.rs
index 3b536c9..77341d5 100644
--- a/rig-core/src/providers/galadriel.rs
+++ b/rig-core/src/providers/galadriel.rs
@@ -13,7 +13,7 @@
 use super::openai;
 use crate::json_utils::merge;
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -495,10 +495,18 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
-    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
+    type StreamingResponse = openai::StreamingCompletionResponse;
+
+    async fn stream(
+        &self,
+        request: CompletionRequest,
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(request)?;
 
-        request = merge(request, json!({"stream": true}));
+        request = merge(
+            request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );
 
         let builder = self.client.post("/chat/completions").json(&request);
diff --git a/rig-core/src/providers/gemini/completion.rs b/rig-core/src/providers/gemini/completion.rs
index d9d35c2..7491fc5 100644
--- a/rig-core/src/providers/gemini/completion.rs
+++ b/rig-core/src/providers/gemini/completion.rs
@@ -609,7 +609,7 @@ pub mod gemini_api_types {
         HarmCategoryCivicIntegrity,
     }
 
-    #[derive(Debug, Deserialize)]
+    #[derive(Debug, Deserialize, Clone, Default)]
     #[serde(rename_all = "camelCase")]
     pub struct UsageMetadata {
         pub prompt_token_count: i32,
diff --git a/rig-core/src/providers/gemini/streaming.rs b/rig-core/src/providers/gemini/streaming.rs
index 362d9ae..e48c0b8 100644
--- a/rig-core/src/providers/gemini/streaming.rs
+++ b/rig-core/src/providers/gemini/streaming.rs
@@ -2,12 +2,17 @@ use async_stream::stream;
 use futures::StreamExt;
 use serde::Deserialize;
 
+use super::completion::{create_request_body, gemini_api_types::ContentCandidate, CompletionModel};
 use crate::{
     completion::{CompletionError, CompletionRequest},
-    streaming::{self, StreamingCompletionModel, StreamingResult},
+    streaming::{self, StreamingCompletionModel},
 };
 
-use super::completion::{create_request_body, gemini_api_types::ContentCandidate, CompletionModel};
+#[derive(Debug, Deserialize, Default, Clone)]
+#[serde(rename_all = "camelCase")]
+pub struct PartialUsage {
+    pub total_token_count: i32,
+}
 
 #[derive(Debug, Deserialize)]
 #[serde(rename_all = "camelCase")]
@@ -15,13 +20,21 @@ pub struct StreamGenerateContentResponse {
     /// Candidate responses from the model.
     pub candidates: Vec<ContentCandidate>,
     pub model_version: Option<String>,
+    pub usage_metadata: Option<PartialUsage>,
+}
+
+#[derive(Clone)]
+pub struct StreamingCompletionResponse {
+    pub usage_metadata: PartialUsage,
 }
 
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
+    {
         let request = create_request_body(completion_request)?;
 
         let response = self
@@ -42,7 +55,7 @@ impl StreamingCompletionModel for CompletionModel {
             )));
         }
 
-        Ok(Box::pin(stream! {
+        let stream = Box::pin(stream! {
             let mut stream = response.bytes_stream();
 
             while let Some(chunk_result) = stream.next().await {
@@ -74,13 +87,23 @@ impl StreamingCompletionModel for CompletionModel {
 
                     match choice.content.parts.first() {
                         super::completion::gemini_api_types::Part::Text(text)
-                            => yield Ok(streaming::StreamingChoice::Message(text)),
+                            => yield Ok(streaming::RawStreamingChoice::Message(text)),
                         super::completion::gemini_api_types::Part::FunctionCall(function_call)
-                            => yield Ok(streaming::StreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
+                            => yield Ok(streaming::RawStreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
                         _ => panic!("Unsupported response type with streaming.")
                     };
+
+                    if choice.finish_reason.is_some() {
+                        yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
+                            usage_metadata: PartialUsage {
+                                total_token_count: data.usage_metadata.unwrap().total_token_count,
+                            }
+                        }))
+                    }
                 }
             }
-        }))
+        });
+
+        Ok(streaming::StreamingCompletionResponse::new(stream))
     }
 }
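One sharp edge in the Gemini stream above: the final-chunk branch unwrap()s usage_metadata on the assumption that Gemini always sends usageMetadata alongside a finish reason. Since this patch derives Default for both UsageMetadata and PartialUsage, a defensive variant would be possible (a sketch, not what the patch does):

    // Avoids a panic if Gemini ever omits usageMetadata on the final chunk:
    let total_token_count = data.usage_metadata.unwrap_or_default().total_token_count;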
diff --git a/rig-core/src/providers/groq.rs b/rig-core/src/providers/groq.rs
index f852e77..ffdf29a 100644
--- a/rig-core/src/providers/groq.rs
+++ b/rig-core/src/providers/groq.rs
@@ -10,7 +10,8 @@
 //! ```
 use super::openai::{send_compatible_streaming_request, CompletionResponse, TranscriptionResponse};
 use crate::json_utils::merge;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::providers::openai;
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -363,10 +364,17 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
-    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
+    type StreamingResponse = openai::StreamingCompletionResponse;
+    async fn stream(
+        &self,
+        request: CompletionRequest,
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(request)?;
 
-        request = merge(request, json!({"stream": true}));
+        request = merge(
+            request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );
 
         let builder = self.client.post("/chat/completions").json(&request);
diff --git a/rig-core/src/providers/huggingface/streaming.rs b/rig-core/src/providers/huggingface/streaming.rs
index 8eea985..90ddd39 100644
--- a/rig-core/src/providers/huggingface/streaming.rs
+++ b/rig-core/src/providers/huggingface/streaming.rs
@@ -1,9 +1,9 @@
 use super::completion::CompletionModel;
 use crate::completion::{CompletionError, CompletionRequest};
-use crate::json_utils;
 use crate::json_utils::merge_inplace;
-use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::providers::openai::{send_compatible_streaming_request, StreamingCompletionResponse};
+use crate::streaming::StreamingCompletionModel;
+use crate::{json_utils, streaming};
 use serde::{Deserialize, Serialize};
 use serde_json::{json, Value};
 use std::convert::Infallible;
@@ -55,14 +55,19 @@ struct CompletionChunk {
 }
 
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
+    {
         let mut request = self.create_request_body(&completion_request)?;
 
         // Enable streaming
-        merge_inplace(&mut request, json!({"stream": true}));
+        merge_inplace(
+            &mut request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );
 
         if let Some(ref params) = completion_request.additional_params {
             merge_inplace(&mut request, params.clone());
diff --git a/rig-core/src/providers/hyperbolic.rs b/rig-core/src/providers/hyperbolic.rs
index 1606635..e0a69e2 100644
--- a/rig-core/src/providers/hyperbolic.rs
+++ b/rig-core/src/providers/hyperbolic.rs
@@ -12,7 +12,7 @@ use super::openai::{send_compatible_streaming_request, AssistantContent};
 
 use crate::json_utils::merge_inplace;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -390,13 +390,17 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;
 
-        merge_inplace(&mut request, json!({"stream": true}));
+        merge_inplace(
+            &mut request,
+ json!({"stream": true, "stream_options": {"include_usage": true}}), + ); let builder = self.client.post("/chat/completions").json(&request); @@ -526,8 +530,10 @@ mod image_generation { // ====================================== // Hyperbolic Audio Generation API // ====================================== +use crate::providers::openai; #[cfg(feature = "audio")] pub use audio_generation::*; + #[cfg(feature = "audio")] mod audio_generation { use super::{ApiResponse, Client}; diff --git a/rig-core/src/providers/mira.rs b/rig-core/src/providers/mira.rs index 1d11c19..16687bf 100644 --- a/rig-core/src/providers/mira.rs +++ b/rig-core/src/providers/mira.rs @@ -8,8 +8,9 @@ //! //! ``` use crate::json_utils::merge; +use crate::providers::openai; use crate::providers::openai::send_compatible_streaming_request; -use crate::streaming::{StreamingCompletionModel, StreamingResult}; +use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse}; use crate::{ agent::AgentBuilder, completion::{self, CompletionError, CompletionRequest}, @@ -347,10 +348,11 @@ impl completion::CompletionModel for CompletionModel { } impl StreamingCompletionModel for CompletionModel { + type StreamingResponse = openai::StreamingCompletionResponse; async fn stream( &self, completion_request: CompletionRequest, - ) -> Result { + ) -> Result, CompletionError> { let mut request = self.create_completion_request(completion_request)?; request = merge(request, json!({"stream": true})); diff --git a/rig-core/src/providers/moonshot.rs b/rig-core/src/providers/moonshot.rs index 278863a..9023133 100644 --- a/rig-core/src/providers/moonshot.rs +++ b/rig-core/src/providers/moonshot.rs @@ -11,7 +11,7 @@ use crate::json_utils::merge; use crate::providers::openai::send_compatible_streaming_request; -use crate::streaming::{StreamingCompletionModel, StreamingResult}; +use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse}; use crate::{ agent::AgentBuilder, completion::{self, CompletionError, CompletionRequest}, @@ -228,10 +228,18 @@ impl completion::CompletionModel for CompletionModel { } impl StreamingCompletionModel for CompletionModel { - async fn stream(&self, request: CompletionRequest) -> Result { + type StreamingResponse = openai::StreamingCompletionResponse; + + async fn stream( + &self, + request: CompletionRequest, + ) -> Result, CompletionError> { let mut request = self.create_completion_request(request)?; - request = merge(request, json!({"stream": true})); + request = merge( + request, + json!({"stream": true, "stream_options": {"include_usage": true}}), + ); let builder = self.client.post("/chat/completions").json(&request); diff --git a/rig-core/src/providers/ollama.rs b/rig-core/src/providers/ollama.rs index 4b539b8..52df410 100644 --- a/rig-core/src/providers/ollama.rs +++ b/rig-core/src/providers/ollama.rs @@ -39,7 +39,7 @@ //! let extractor = client.extractor::("llama3.2"); //! 
diff --git a/rig-core/src/providers/ollama.rs b/rig-core/src/providers/ollama.rs
index 4b539b8..52df410 100644
--- a/rig-core/src/providers/ollama.rs
+++ b/rig-core/src/providers/ollama.rs
@@ -39,7 +39,7 @@
 //! let extractor = client.extractor::<serde_json::Value>("llama3.2");
 //! ```
 use crate::json_utils::merge_inplace;
-use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
+use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -47,7 +47,7 @@ use crate::{
     extractor::ExtractorBuilder,
     json_utils, message,
     message::{ImageDetail, Text},
-    Embed, OneOrMany,
+    streaming, Embed, OneOrMany,
 };
 use async_stream::stream;
 use futures::StreamExt;
@@ -405,8 +405,25 @@ impl completion::CompletionModel for CompletionModel {
     }
 }
 
+#[derive(Clone)]
+pub struct StreamingCompletionResponse {
+    pub done_reason: Option<String>,
+    pub total_duration: Option<u64>,
+    pub load_duration: Option<u64>,
+    pub prompt_eval_count: Option<u64>,
+    pub prompt_eval_duration: Option<u64>,
+    pub eval_count: Option<u64>,
+    pub eval_duration: Option<u64>,
+}
+
 impl StreamingCompletionModel for CompletionModel {
-    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
+    type StreamingResponse = StreamingCompletionResponse;
+
+    async fn stream(
+        &self,
+        request: CompletionRequest,
+    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
+    {
         let mut request_payload = self.create_completion_request(request)?;
         merge_inplace(&mut request_payload, json!({"stream": true}));
 
@@ -426,7 +443,7 @@ impl StreamingCompletionModel for CompletionModel {
             return Err(CompletionError::ProviderError(err_text));
         }
 
-        Ok(Box::pin(stream! {
+        let stream = Box::pin(stream! {
             let mut stream = response.bytes_stream();
             while let Some(chunk_result) = stream.next().await {
                 let chunk = match chunk_result {
@@ -456,22 +473,36 @@ impl StreamingCompletionModel for CompletionModel {
                 match response.message {
                     Message::Assistant{ content, tool_calls, .. } => {
                         if !content.is_empty() {
-                            yield Ok(StreamingChoice::Message(content))
+                            yield Ok(RawStreamingChoice::Message(content))
                         }
 
                         for tool_call in tool_calls.iter() {
                             let function = tool_call.function.clone();
-                            yield Ok(StreamingChoice::ToolCall(function.name, "".to_string(), function.arguments));
+                            yield Ok(RawStreamingChoice::ToolCall(function.name, "".to_string(), function.arguments));
                         }
                     },
                     _ => { continue; }
                 }
+
+                if response.done {
+                    yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
+                        total_duration: response.total_duration,
+                        load_duration: response.load_duration,
+                        prompt_eval_count: response.prompt_eval_count,
+                        prompt_eval_duration: response.prompt_eval_duration,
+                        eval_count: response.eval_count,
+                        eval_duration: response.eval_duration,
+                        done_reason: response.done_reason,
+                    }));
+                }
             }
-        }))
+        });
+
+        Ok(streaming::StreamingCompletionResponse::new(stream))
     }
 }
diff --git a/rig-core/src/providers/openai/streaming.rs b/rig-core/src/providers/openai/streaming.rs
index 3727468..ad6ac41 100644
--- a/rig-core/src/providers/openai/streaming.rs
+++ b/rig-core/src/providers/openai/streaming.rs
@@ -2,14 +2,16 @@ use super::completion::CompletionModel;
 use crate::completion::{CompletionError, CompletionRequest};
 use crate::json_utils;
 use crate::json_utils::merge;
+use crate::providers::openai::Usage;
 use crate::streaming;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
 use async_stream::stream;
 use futures::StreamExt;
 use reqwest::RequestBuilder;
 use serde::{Deserialize, Serialize};
 use serde_json::json;
 use std::collections::HashMap;
+use tracing::debug;
 
 // ================================================================
 // OpenAI Completion Streaming API
 // ================================================================
@@ -25,10 +27,11 @@ pub struct StreamingFunction {
 
 #[derive(Debug, Serialize, Deserialize, Clone)]
 pub struct StreamingToolCall {
     pub index: usize,
+    pub id: Option<String>,
     pub function: StreamingFunction,
 }
 
-#[derive(Deserialize)]
+#[derive(Deserialize, Debug)]
 struct StreamingDelta {
     #[serde(default)]
     content: Option<String>,
@@ -36,23 +39,34 @@ struct StreamingDelta {
     tool_calls: Vec<StreamingToolCall>,
 }
 
-#[derive(Deserialize)]
+#[derive(Deserialize, Debug)]
 struct StreamingChoice {
     delta: StreamingDelta,
 }
 
-#[derive(Deserialize)]
-struct StreamingCompletionResponse {
+#[derive(Deserialize, Debug)]
+struct StreamingCompletionChunk {
     choices: Vec<StreamingChoice>,
+    usage: Option<Usage>,
+}
+
+#[derive(Clone)]
+pub struct StreamingCompletionResponse {
+    pub usage: Usage,
 }
 
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
+    {
         let mut request = self.create_completion_request(completion_request)?;
-        request = merge(request, json!({"stream": true}));
+
+        request = merge(
+            request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );
 
         let builder = self.client.post("/chat/completions").json(&request);
         send_compatible_streaming_request(builder).await
@@ -61,7 +75,7 @@ impl StreamingCompletionModel for CompletionModel {
 
 pub async fn send_compatible_streaming_request(
     request_builder: RequestBuilder,
-) -> Result<StreamingResult, CompletionError> {
+) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError> {
     let response = request_builder.send().await?;
 
     if !response.status().is_success() {
@@ -73,11 +87,16 @@ pub async fn send_compatible_streaming_request(
     }
 
     // Handle OpenAI Compatible SSE chunks
-    Ok(Box::pin(stream! {
+    let inner = Box::pin(stream! {
         let mut stream = response.bytes_stream();
 
+        let mut final_usage = Usage {
+            prompt_tokens: 0,
+            total_tokens: 0
+        };
+
         let mut partial_data = None;
-        let mut calls: HashMap<usize, (String, String)> = HashMap::new();
+        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
 
         while let Some(chunk_result) = stream.next().await {
             let chunk = match chunk_result {
@@ -100,8 +119,6 @@ pub async fn send_compatible_streaming_request(
 
             for line in text.lines() {
                 let mut line = line.to_string();
-
-                // If there was a remaining part, concat with current line
                 if partial_data.is_some() {
                     line = format!("{}{}", partial_data.unwrap(), line);
@@ -121,64 +138,85 @@ pub async fn send_compatible_streaming_request(
                     }
                 }
 
-                let data = serde_json::from_str::<StreamingCompletionResponse>(&line);
+                let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
 
                 let Ok(data) = data else {
+                    let err = data.unwrap_err();
+                    debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
                     continue;
                 };
 
-                let choice = data.choices.first().expect("Should have at least one choice");
-                let delta = &choice.delta;
+                if let Some(choice) = data.choices.first() {
 
-                if !delta.tool_calls.is_empty() {
-                    for tool_call in &delta.tool_calls {
-                        let function = tool_call.function.clone();
+                    let delta = &choice.delta;
 
-                        // Start of tool call
-                        // name: Some(String)
-                        // arguments: None
-                        if function.name.is_some() && function.arguments.is_empty() {
-                            calls.insert(tool_call.index, (function.name.clone().unwrap(), "".to_string()));
+                    if !delta.tool_calls.is_empty() {
+                        for tool_call in &delta.tool_calls {
+                            let function = tool_call.function.clone();
+                            // Start of tool call
+                            // name: Some(String)
+                            // arguments: None
+                            if function.name.is_some() && function.arguments.is_empty() {
+                                let id = tool_call.id.clone().unwrap_or("".to_string());
+
+                                calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
+                            }
+                            // Part of tool call
+                            // name: None
+                            // arguments: Some(String)
+                            else if function.name.is_none() && !function.arguments.is_empty() {
+                                let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
+                                    debug!("Partial tool call received but tool call was never started.");
+                                    continue;
+                                };
+
+                                let new_arguments = &tool_call.function.arguments;
+                                let arguments = format!("{}{}", arguments, new_arguments);
+
+                                calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
+                            }
+                            // Entire tool call
+                            else {
+                                let id = tool_call.id.clone().unwrap_or("".to_string());
+                                let name = function.name.expect("function name should be present for complete tool call");
+                                let arguments = function.arguments;
+                                let Ok(arguments) = serde_json::from_str(&arguments) else {
+                                    debug!("Couldn't serialize '{}' as a json value", arguments);
+                                    continue;
+                                };
+
+                                yield Ok(streaming::RawStreamingChoice::ToolCall(id, name, arguments))
+                            }
+                        }
                     }
-                        // Part of tool call
-                        // name: None
-                        // arguments: Some(String)
-                        else if function.name.is_none() && !function.arguments.is_empty() {
-                            let Some((name, arguments)) = calls.get(&tool_call.index) else {
-                                continue;
-                            };
-
-                            let new_arguments = &tool_call.function.arguments;
-                            let arguments = format!("{}{}", arguments, new_arguments);
-
-                            calls.insert(tool_call.index, (name.clone(), arguments));
-                        }
-                        // Entire tool call
-                        else {
-                            let name = function.name.unwrap();
-                            let arguments = function.arguments;
-                            let Ok(arguments) = serde_json::from_str(&arguments) else {
-                                continue;
-                            };
-
-                            yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
-                        }
-                    }
-                }
 
-                if let Some(content) = &choice.delta.content {
-                    yield Ok(streaming::StreamingChoice::Message(content.clone()))
+                    if let Some(content) = &choice.delta.content {
+                        yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
+                    }
+                }
+
+                if let Some(usage) = data.usage {
+                    final_usage = usage.clone();
                 }
             }
         }
 
-        for (_, (name, arguments)) in calls {
+        for (_, (id, name, arguments)) in calls {
             let Ok(arguments) = serde_json::from_str(&arguments) else {
                 continue;
             };
 
-            yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
+            println!("{id} {name}");
+
+            yield Ok(RawStreamingChoice::ToolCall(id, name, arguments))
         }
-    }))
+
+        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
+            usage: final_usage.clone()
+        }))
+    });
+
+    Ok(streaming::StreamingCompletionResponse::new(inner))
 }
diff --git a/rig-core/src/providers/openrouter/client.rs b/rig-core/src/providers/openrouter/client.rs
new file mode 100644
index 0000000..2eab6c1
--- /dev/null
+++ b/rig-core/src/providers/openrouter/client.rs
@@ -0,0 +1,125 @@
+use crate::{agent::AgentBuilder, extractor::ExtractorBuilder};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+
+use super::completion::CompletionModel;
+
+// ================================================================
+// Main openrouter Client
+// ================================================================
+const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
+
+#[derive(Clone)]
+pub struct Client {
+    base_url: String,
+    http_client: reqwest::Client,
+}
+
+impl Client {
+    /// Create a new OpenRouter client with the given API key.
+    pub fn new(api_key: &str) -> Self {
+        Self::from_url(api_key, OPENROUTER_API_BASE_URL)
+    }
+
+    /// Create a new OpenRouter client with the given API key and base API URL.
+    pub fn from_url(api_key: &str, base_url: &str) -> Self {
+        Self {
+            base_url: base_url.to_string(),
+            http_client: reqwest::Client::builder()
+                .default_headers({
+                    let mut headers = reqwest::header::HeaderMap::new();
+                    headers.insert(
+                        "Authorization",
+                        format!("Bearer {}", api_key)
+                            .parse()
+                            .expect("Bearer token should parse"),
+                    );
+                    headers
+                })
+                .build()
+                .expect("OpenRouter reqwest client should build"),
+        }
+    }
+
+    /// Create a new OpenRouter client from the `OPENROUTER_API_KEY` environment variable.
+    /// Panics if the environment variable is not set.
+    pub fn from_env() -> Self {
+        let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
+        Self::new(&api_key)
+    }
+
+    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
+        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
+        self.http_client.post(url)
+    }
+
+    /// Create a completion model with the given name.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::openrouter::{Client, self};
+    ///
+    /// // Initialize the OpenRouter client
+    /// let openrouter = Client::new("your-openrouter-api-key");
+    ///
+    /// let llama_3_1_8b = openrouter.completion_model(openrouter::LLAMA_3_1_8B);
+    /// ```
+    pub fn completion_model(&self, model: &str) -> CompletionModel {
+        CompletionModel::new(self.clone(), model)
+    }
+
+    /// Create an agent builder with the given completion model.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::openrouter::{Client, self};
+    ///
+    /// // Initialize the OpenRouter client
+    /// let openrouter = Client::new("your-openrouter-api-key");
+    ///
+    /// let agent = openrouter.agent(openrouter::LLAMA_3_1_8B)
+    ///     .preamble("You are comedian AI with a mission to make people laugh.")
+    ///     .temperature(0.0)
+    ///     .build();
+    /// ```
+    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
+        AgentBuilder::new(self.completion_model(model))
+    }
+
+    /// Create an extractor builder with the given completion model.
+    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
+        &self,
+        model: &str,
+    ) -> ExtractorBuilder<T, CompletionModel> {
+        ExtractorBuilder::new(self.completion_model(model))
+    }
+}
+
+#[derive(Debug, Deserialize)]
+pub struct ApiErrorResponse {
+    pub(crate) message: String,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(untagged)]
+pub enum ApiResponse<T> {
+    Ok(T),
+    Err(ApiErrorResponse),
+}
+
+#[derive(Clone, Debug, Deserialize)]
+pub struct Usage {
+    pub prompt_tokens: usize,
+    pub completion_tokens: usize,
+    pub total_tokens: usize,
+}
+
+impl std::fmt::Display for Usage {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "Prompt tokens: {} Total tokens: {}",
+            self.prompt_tokens, self.total_tokens
+        )
+    }
+}
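One quirk worth knowing about post() above: .replace("//", "/") also rewrites the scheme's double slash, so the joined string comes out as https:/openrouter.ai/... rather than https://openrouter.ai/.... WHATWG-style URL parsers (including the url crate that reqwest uses) normalize a lone https:/ back to https://, so requests should still resolve, but the intermediate string is not a clean URL:

    let url = format!("{}/{}", "https://openrouter.ai/api/v1", "/chat/completions")
        .replace("//", "/");
    assert_eq!(url, "https:/openrouter.ai/api/v1/chat/completions");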
diff --git a/rig-core/src/providers/openrouter.rs b/rig-core/src/providers/openrouter/completion.rs
similarity index 57%
rename from rig-core/src/providers/openrouter.rs
rename to rig-core/src/providers/openrouter/completion.rs
index 149b2ff..85b9bfe 100644
--- a/rig-core/src/providers/openrouter.rs
+++ b/rig-core/src/providers/openrouter/completion.rs
@@ -1,147 +1,16 @@
-//! OpenRouter Inference API client and Rig integration
-//!
-//! # Example
-//! ```
-//! use rig::providers::openrouter;
-//!
-//! let client = openrouter::Client::new("YOUR_API_KEY");
-//!
-//! let llama_3_1_8b = client.completion_model(openrouter::LLAMA_3_1_8B);
-//! ```
+use serde::Deserialize;
+
+use super::client::{ApiErrorResponse, ApiResponse, Client, Usage};
 use crate::{
-    agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
-    extractor::ExtractorBuilder,
     json_utils,
     providers::openai::Message,
     OneOrMany,
 };
-use schemars::JsonSchema;
-use serde::{Deserialize, Serialize};
-use serde_json::json;
+use serde_json::{json, Value};
 
-use super::openai::AssistantContent;
-
-// ================================================================
-// Main openrouter Client
-// ================================================================
-const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
-
-#[derive(Clone)]
-pub struct Client {
-    base_url: String,
-    http_client: reqwest::Client,
-}
-
-impl Client {
-    /// Create a new OpenRouter client with the given API key.
-    pub fn new(api_key: &str) -> Self {
-        Self::from_url(api_key, OPENROUTER_API_BASE_URL)
-    }
-
-    /// Create a new OpenRouter client with the given API key and base API URL.
-    pub fn from_url(api_key: &str, base_url: &str) -> Self {
-        Self {
-            base_url: base_url.to_string(),
-            http_client: reqwest::Client::builder()
-                .default_headers({
-                    let mut headers = reqwest::header::HeaderMap::new();
-                    headers.insert(
-                        "Authorization",
-                        format!("Bearer {}", api_key)
-                            .parse()
-                            .expect("Bearer token should parse"),
-                    );
-                    headers
-                })
-                .build()
-                .expect("OpenRouter reqwest client should build"),
-        }
-    }
-
-    /// Create a new openrouter client from the `openrouter_API_KEY` environment variable.
-    /// Panics if the environment variable is not set.
-    pub fn from_env() -> Self {
-        let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
-        Self::new(&api_key)
-    }
-
-    fn post(&self, path: &str) -> reqwest::RequestBuilder {
-        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
-        self.http_client.post(url)
-    }
-
-    /// Create a completion model with the given name.
-    ///
-    /// # Example
-    /// ```
-    /// use rig::providers::openrouter::{Client, self};
-    ///
-    /// // Initialize the openrouter client
-    /// let openrouter = Client::new("your-openrouter-api-key");
-    ///
-    /// let llama_3_1_8b = openrouter.completion_model(openrouter::LLAMA_3_1_8B);
-    /// ```
-    pub fn completion_model(&self, model: &str) -> CompletionModel {
-        CompletionModel::new(self.clone(), model)
-    }
-
-    /// Create an agent builder with the given completion model.
-    ///
-    /// # Example
-    /// ```
-    /// use rig::providers::openrouter::{Client, self};
-    ///
-    /// // Initialize the Eternal client
-    /// let openrouter = Client::new("your-openrouter-api-key");
-    ///
-    /// let agent = openrouter.agent(openrouter::LLAMA_3_1_8B)
-    ///     .preamble("You are comedian AI with a mission to make people laugh.")
-    ///     .temperature(0.0)
-    ///     .build();
-    /// ```
-    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
-        AgentBuilder::new(self.completion_model(model))
-    }
-
-    /// Create an extractor builder with the given completion model.
-    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
-        &self,
-        model: &str,
-    ) -> ExtractorBuilder<T, CompletionModel> {
-        ExtractorBuilder::new(self.completion_model(model))
-    }
-}
-
-#[derive(Debug, Deserialize)]
-struct ApiErrorResponse {
-    message: String,
-}
-
-#[derive(Debug, Deserialize)]
-#[serde(untagged)]
-enum ApiResponse<T> {
-    Ok(T),
-    Err(ApiErrorResponse),
-}
-
-#[derive(Clone, Debug, Deserialize)]
-pub struct Usage {
-    pub prompt_tokens: usize,
-    pub completion_tokens: usize,
-    pub total_tokens: usize,
-}
-
-impl std::fmt::Display for Usage {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(
-            f,
-            "Prompt tokens: {} Total tokens: {}",
-            self.prompt_tokens, self.total_tokens
-        )
-    }
-}
+use crate::providers::openai::AssistantContent;
 
 // ================================================================
 // OpenRouter Completion API
@@ -241,7 +110,7 @@ pub struct Choice {
 
 #[derive(Clone)]
 pub struct CompletionModel {
-    client: Client,
+    pub(crate) client: Client,
     /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
     pub model: String,
 }
@@ -253,16 +122,11 @@ impl CompletionModel {
             model: model.to_string(),
         }
     }
-}
 
-impl completion::CompletionModel for CompletionModel {
-    type Response = CompletionResponse;
-
-    #[cfg_attr(feature = "worker", worker::send)]
-    async fn completion(
+    pub(crate) fn create_completion_request(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
+    ) -> Result<Value, CompletionError> {
         // Add preamble to chat history (if available)
         let mut full_history: Vec<Message> = match &completion_request.preamble {
             Some(preamble) => vec![Message::system(preamble)],
@@ -292,16 +156,30 @@ impl completion::CompletionModel for CompletionModel {
             "temperature": completion_request.temperature,
         });
 
+        let request = if let Some(params) = completion_request.additional_params {
+            json_utils::merge(request, params)
+        } else {
+            request
+        };
+
+        Ok(request)
+    }
+}
+
+impl completion::CompletionModel for CompletionModel {
+    type Response = CompletionResponse;
+
+    #[cfg_attr(feature = "worker", worker::send)]
+    async fn completion(
+        &self,
+        completion_request: CompletionRequest,
+    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
+        let request = self.create_completion_request(completion_request)?;
+
         let response = self
             .client
             .post("/chat/completions")
-            .json(
-                &if let Some(params) = completion_request.additional_params {
-                    json_utils::merge(request, params)
-                } else {
-                    request
-                },
-            )
+            .json(&request)
             .send()
             .await?;
diff --git a/rig-core/src/providers/openrouter/mod.rs b/rig-core/src/providers/openrouter/mod.rs
new file mode 100644
index 0000000..9983815
--- /dev/null
+++ b/rig-core/src/providers/openrouter/mod.rs
@@ -0,0 +1,17 @@
+//! OpenRouter Inference API client and Rig integration
+//!
+//! # Example
+//! ```
+//! use rig::providers::openrouter;
+//!
+//! let client = openrouter::Client::new("YOUR_API_KEY");
+//!
+//! let llama_3_1_8b = client.completion_model(openrouter::LLAMA_3_1_8B);
+//! ```
diff --git a/rig-core/src/providers/openrouter/mod.rs b/rig-core/src/providers/openrouter/mod.rs
new file mode 100644
index 0000000..9983815
--- /dev/null
+++ b/rig-core/src/providers/openrouter/mod.rs
@@ -0,0 +1,17 @@
+//! OpenRouter Inference API client and Rig integration
+//!
+//! # Example
+//! ```
+//! use rig::providers::openrouter;
+//!
+//! let client = openrouter::Client::new("YOUR_API_KEY");
+//!
+//! let llama_3_1_8b = client.completion_model(openrouter::LLAMA_3_1_8B);
+//! ```
+
+pub mod client;
+pub mod completion;
+pub mod streaming;
+
+pub use client::*;
+pub use completion::*;
diff --git a/rig-core/src/providers/openrouter/streaming.rs b/rig-core/src/providers/openrouter/streaming.rs
new file mode 100644
index 0000000..1704dcb
--- /dev/null
+++ b/rig-core/src/providers/openrouter/streaming.rs
@@ -0,0 +1,313 @@
+use std::collections::HashMap;
+
+use crate::{
+    json_utils,
+    message::{ToolCall, ToolFunction},
+    streaming::{self},
+};
+use async_stream::stream;
+use futures::StreamExt;
+use reqwest::RequestBuilder;
+use serde_json::{json, Value};
+
+use crate::{
+    completion::{CompletionError, CompletionRequest},
+    streaming::StreamingCompletionModel,
+};
+use serde::{Deserialize, Serialize};
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct StreamingCompletionResponse {
+    pub id: String,
+    pub choices: Vec<StreamingChoice>,
+    pub created: u64,
+    pub model: String,
+    pub object: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub system_fingerprint: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub usage: Option<ResponseUsage>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct StreamingChoice {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub finish_reason: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub native_finish_reason: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub logprobs: Option<Value>,
+    pub index: usize,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub message: Option<MessageResponse>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub delta: Option<DeltaResponse>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error: Option<ErrorResponse>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct MessageResponse {
+    pub role: String,
+    pub content: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub refusal: Option<String>,
+    #[serde(default)]
+    pub tool_calls: Vec<OpenRouterToolCall>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct OpenRouterToolFunction {
+    pub name: Option<String>,
+    pub arguments: Option<String>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct OpenRouterToolCall {
+    pub index: usize,
+    pub id: Option<String>,
+    pub r#type: Option<String>,
+    pub function: OpenRouterToolFunction,
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone, Default)]
+pub struct ResponseUsage {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub total_tokens: u32,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct ErrorResponse {
+    pub code: i32,
+    pub message: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub metadata: Option<HashMap<String, Value>>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct DeltaResponse {
+    pub role: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub content: Option<String>,
+    #[serde(default)]
+    pub tool_calls: Vec<OpenRouterToolCall>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub native_finish_reason: Option<String>,
+}
+
+#[derive(Clone)]
+pub struct FinalCompletionResponse {
+    pub usage: ResponseUsage,
+}
+
+impl StreamingCompletionModel for super::CompletionModel {
+    type StreamingResponse = FinalCompletionResponse;
+
+    async fn stream(
+        &self,
+        completion_request: CompletionRequest,
+    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
+    {
+        let request = self.create_completion_request(completion_request)?;
+
+        let request = json_utils::merge(request, json!({"stream": true}));
+
+        let builder = self.client.post("/chat/completions").json(&request);
+
+        send_streaming_request(builder).await
+    }
+}
+
+pub async fn send_streaming_request(
+    request_builder: RequestBuilder,
+) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError> {
+    let response = request_builder.send().await?;
+
+    if !response.status().is_success() {
+        return Err(CompletionError::ProviderError(format!(
+            "{}: {}",
+            response.status(),
+            response.text().await?
+        )));
+    }
+
+    // Handle OpenAI-compatible SSE chunks
+    let stream = Box::pin(stream! {
+        let mut stream = response.bytes_stream();
+        let mut tool_calls = HashMap::new();
+        let mut partial_line = String::new();
+        let mut final_usage = None;
+
+        while let Some(chunk_result) = stream.next().await {
+            let chunk = match chunk_result {
+                Ok(c) => c,
+                Err(e) => {
+                    yield Err(CompletionError::from(e));
+                    break;
+                }
+            };
+
+            let text = match String::from_utf8(chunk.to_vec()) {
+                Ok(t) => t,
+                Err(e) => {
+                    yield Err(CompletionError::ResponseError(e.to_string()));
+                    break;
+                }
+            };
+
+            for line in text.lines() {
+                let mut line = line.to_string();
+
+                // Skip empty lines, OpenRouter keep-alive comments, and the final [DONE] marker
+                if line.trim().is_empty() || line.trim() == ": OPENROUTER PROCESSING" || line.trim() == "data: [DONE]" {
+                    continue;
+                }
+
+                // Handle the "data: " prefix
+                line = line.strip_prefix("data: ").unwrap_or(&line).to_string();
+
+                // If the line starts with { but doesn't end with }, it's a partial JSON chunk
+                if line.starts_with('{') && !line.ends_with('}') {
+                    partial_line = line;
+                    continue;
+                }
+
+                // If we have a partial line and this line ends with }, complete it
+                if !partial_line.is_empty() {
+                    if line.ends_with('}') {
+                        partial_line.push_str(&line);
+                        line = partial_line;
+                        partial_line = String::new();
+                    } else {
+                        partial_line.push_str(&line);
+                        continue;
+                    }
+                }
+
+                let data = match serde_json::from_str::<StreamingCompletionResponse>(&line) {
+                    Ok(data) => data,
+                    Err(_) => {
+                        continue;
+                    }
+                };
+
+                let choice = data.choices.first().expect("Should have at least one choice");
+
+                // TODO this has to handle outputs like this:
+                // [{"index": 0, "id": "call_DdmO9pD3xa9XTPNJ32zg2hcA", "function": {"arguments": "", "name": "get_weather"}, "type": "function"}]
+                // [{"index": 0, "id": null, "function": {"arguments": "{\"", "name": null}, "type": null}]
+                // [{"index": 0, "id": null, "function": {"arguments": "location", "name": null}, "type": null}]
+                // [{"index": 0, "id": null, "function": {"arguments": "\":\"", "name": null}, "type": null}]
+                // [{"index": 0, "id": null, "function": {"arguments": "Paris", "name": null}, "type": null}]
+                // [{"index": 0, "id": null, "function": {"arguments": ",", "name": null}, "type": null}]
+                // [{"index": 0, "id": null, "function": {"arguments": " France", "name": null}, "type": null}]
+                // [{"index": 0, "id": null, "function": {"arguments": "\"}", "name": null}, "type": null}]
+                if let Some(delta) = &choice.delta {
+                    if !delta.tool_calls.is_empty() {
+                        for tool_call in &delta.tool_calls {
+                            let index = tool_call.index;
+
+                            // Get or create the tool call entry
+                            let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
+                                id: String::new(),
+                                function: ToolFunction {
+                                    name: String::new(),
+                                    arguments: serde_json::Value::Null,
+                                },
+                            });
+
+                            // Update fields if present
+                            if let Some(id) = &tool_call.id {
+                                if !id.is_empty() {
+                                    existing_tool_call.id = id.clone();
+                                }
+                            }
+                            if let Some(name) = &tool_call.function.name {
+                                if !name.is_empty() {
+                                    existing_tool_call.function.name = name.clone();
+                                }
+                            }
+                            if let Some(chunk) = &tool_call.function.arguments {
+                                // Convert current arguments to a string if needed
+                                let current_args = match &existing_tool_call.function.arguments {
+                                    serde_json::Value::Null => String::new(),
+                                    serde_json::Value::String(s) => s.clone(),
+                                    v => v.to_string(),
+                                };
+
+                                // Concatenate the new chunk
+                                let combined = format!("{}{}", current_args, chunk);
+
+                                // Try to parse as JSON if it looks complete
+                                if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
+                                    match serde_json::from_str(&combined) {
+                                        Ok(parsed) => existing_tool_call.function.arguments = parsed,
+                                        Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
+                                    }
+                                } else {
+                                    existing_tool_call.function.arguments = serde_json::Value::String(combined);
+                                }
+                            }
+                        }
+                    }
+
+                    if let Some(content) = &delta.content {
+                        if !content.is_empty() {
+                            yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
+                        }
+                    }
+
+                    if let Some(usage) = data.usage {
+                        final_usage = Some(usage);
+                    }
+                }
+
+                // Handle the message format
+                if let Some(message) = &choice.message {
+                    if !message.tool_calls.is_empty() {
+                        for tool_call in &message.tool_calls {
+                            let name = tool_call.function.name.clone();
+                            let id = tool_call.id.clone();
+                            let arguments = if let Some(args) = &tool_call.function.arguments {
+                                // Try to parse the string as JSON, falling back to a string value
+                                match serde_json::from_str(args) {
+                                    Ok(v) => v,
+                                    Err(_) => serde_json::Value::String(args.to_string()),
+                                }
+                            } else {
+                                serde_json::Value::Null
+                            };
+                            let index = tool_call.index;
+
+                            tool_calls.insert(index, ToolCall {
+                                id: id.unwrap_or_default(),
+                                function: ToolFunction {
+                                    name: name.unwrap_or_default(),
+                                    arguments,
+                                },
+                            });
+                        }
+                    }
+
+                    if !message.content.is_empty() {
+                        yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()))
+                    }
+                }
+            }
+        }
+
+        // Yield accumulated tool calls as (id, name, arguments), matching the
+        // order the core stream consumer destructures them in.
+        for (_, tool_call) in tool_calls.into_iter() {
+            yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call.id, tool_call.function.name, tool_call.function.arguments));
+        }
+
+        yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
+            usage: final_usage.unwrap_or_default()
+        }))
+    });
+
+    Ok(streaming::StreamingCompletionResponse::new(stream))
+}
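The delta branch above reassembles tool-call arguments that arrive as bare string fragments (see the TODO comment for a real sequence). A self-contained sketch of that accumulation step, using the fragment sequence from the comment; illustration only, not part of the diff:

```rust
use serde_json::Value;

fn main() {
    // Argument fragments as they arrive across deltas for a single tool call,
    // taken from the TODO comment above.
    let fragments = [r#"{""#, "location", r#"":""#, "Paris", ",", " France", r#""}"#];

    // Mirror the accumulator: concatenate fragments until the buffer looks
    // like a complete JSON object, then parse it.
    let mut buffer = String::new();
    for fragment in fragments {
        buffer.push_str(fragment);
    }

    assert!(buffer.starts_with('{') && buffer.ends_with('}'));
    let arguments: Value = serde_json::from_str(&buffer).expect("complete JSON object");
    assert_eq!(arguments["location"], "Paris, France");
    println!("parsed arguments: {arguments}");
}
```

Until the buffer parses, the handler stores it as a plain `Value::String`, which is why each entry's `arguments` is either `Null`, a partial `String`, or a fully parsed object.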
diff --git a/rig-core/src/providers/perplexity.rs b/rig-core/src/providers/perplexity.rs
index 1b3d1ea..c09c4d6 100644
--- a/rig-core/src/providers/perplexity.rs
+++ b/rig-core/src/providers/perplexity.rs
@@ -18,8 +18,9 @@ use crate::{
 use crate::completion::CompletionRequest;
 use crate::json_utils::merge;
+use crate::providers::openai;
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use serde_json::{json, Value};
@@ -345,10 +346,11 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: completion::CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;
 
         request = merge(request, json!({"stream": true}));
diff --git a/rig-core/src/providers/together/streaming.rs b/rig-core/src/providers/together/streaming.rs
index dffb3b8..f437828 100644
--- a/rig-core/src/providers/together/streaming.rs
+++ b/rig-core/src/providers/together/streaming.rs
@@ -1,18 +1,21 @@
 use serde_json::json;
 
 use super::completion::CompletionModel;
+use crate::providers::openai;
 use crate::providers::openai::send_compatible_streaming_request;
+use crate::streaming::StreamingCompletionResponse;
 use crate::{
     completion::{CompletionError, CompletionRequest},
     json_utils::merge,
-    streaming::{StreamingCompletionModel, StreamingResult},
+    streaming::StreamingCompletionModel,
 };
 
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;
 
         request = merge(request, json!({"stream_tokens": true}));
diff --git a/rig-core/src/providers/xai/streaming.rs b/rig-core/src/providers/xai/streaming.rs
index 80ab430..b87a53e 100644
--- a/rig-core/src/providers/xai/streaming.rs
+++ b/rig-core/src/providers/xai/streaming.rs
@@ -1,15 +1,17 @@
 use crate::completion::{CompletionError, CompletionRequest};
 use crate::json_utils::merge;
+use crate::providers::openai;
 use crate::providers::openai::send_compatible_streaming_request;
 use crate::providers::xai::completion::CompletionModel;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use serde_json::json;
 
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;
 
         request = merge(request, json!({"stream": true}));
diff --git a/rig-core/src/streaming.rs b/rig-core/src/streaming.rs
index 76f91c4..7a58498 100644
--- a/rig-core/src/streaming.rs
+++ b/rig-core/src/streaming.rs
@@ -11,59 +11,150 @@
 use crate::agent::Agent;
 use crate::completion::{
-    CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, Message,
+    CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder,
+    CompletionResponse, Message,
 };
+use crate::message::{AssistantContent, ToolCall, ToolFunction};
+use crate::OneOrMany;
 use futures::{Stream, StreamExt};
 use std::boxed::Box;
-use std::fmt::{Display, Formatter};
 use std::future::Future;
 use std::pin::Pin;
+use std::task::{Context, Poll};
 
 /// Enum representing a streaming chunk from the model
-#[derive(Debug)]
-pub enum StreamingChoice {
+#[derive(Debug, Clone)]
+pub enum RawStreamingChoice<R: Clone> {
     /// A text chunk from a message response
     Message(String),
 
     /// A tool call response chunk
     ToolCall(String, String, serde_json::Value),
+
+    /// The final response object; it must be yielded if you want the
+    /// `response` field to be populated on the `StreamingCompletionResponse`
+    FinalResponse(R),
 }
 
-impl Display for StreamingChoice {
-    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        match self {
-            StreamingChoice::Message(text) => write!(f, "{}", text),
-            StreamingChoice::ToolCall(name, id, params) => {
-                write!(f, "Tool call: {} {} {:?}", name, id, params)
-            }
+#[cfg(not(target_arch = "wasm32"))]
+pub type StreamingResult<R> =
+    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;
+
+#[cfg(target_arch = "wasm32")]
+pub type StreamingResult<R> =
+    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
+
+/// The response from a streaming completion request;
+/// `choice` and `response` are populated once the
+/// `inner` stream is exhausted.
+pub struct StreamingCompletionResponse<R: Clone + Unpin> {
+    inner: StreamingResult<R>,
+    text: String,
+    tool_calls: Vec<ToolCall>,
+    /// The final aggregated message from the stream;
+    /// contains all text and tool calls generated
+    pub choice: OneOrMany<AssistantContent>,
+    /// The final response from the stream; may be `None`
+    /// if the provider didn't yield it during the stream
+    pub response: Option<R>,
+}
+
+impl<R: Clone + Unpin> StreamingCompletionResponse<R> {
+    pub fn new(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
+        Self {
+            inner,
+            text: "".to_string(),
+            tool_calls: vec![],
+            choice: OneOrMany::one(AssistantContent::text("")),
+            response: None,
         }
     }
 }
 
-#[cfg(not(target_arch = "wasm32"))]
-pub type StreamingResult =
-    Pin<Box<dyn Stream<Item = Result<StreamingChoice, CompletionError>> + Send>>;
+impl<R: Clone + Unpin> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>> {
+    fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
+        CompletionResponse {
+            choice: value.choice,
+            raw_response: value.response,
+        }
+    }
+}
 
-#[cfg(target_arch = "wasm32")]
-pub type StreamingResult = Pin<Box<dyn Stream<Item = Result<StreamingChoice, CompletionError>>>>;
+impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
+    type Item = Result<AssistantContent, CompletionError>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let stream = self.get_mut();
+
+        match stream.inner.as_mut().poll_next(cx) {
+            Poll::Pending => Poll::Pending,
+            Poll::Ready(None) => {
+                // This is run at the end of the inner stream to collect all tokens into
+                // a single unified `Message`.
+                let mut choice = vec![];
+
+                stream.tool_calls.iter().for_each(|tc| {
+                    choice.push(AssistantContent::ToolCall(tc.clone()));
+                });
+
+                // This is required to ensure there's always at least one item in the content
+                if choice.is_empty() || !stream.text.is_empty() {
+                    choice.insert(0, AssistantContent::text(stream.text.clone()));
+                }
+
+                stream.choice = OneOrMany::many(choice)
+                    .expect("There should be at least one assistant message");
+
+                Poll::Ready(None)
+            }
+            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
+            Poll::Ready(Some(Ok(choice))) => match choice {
+                RawStreamingChoice::Message(text) => {
+                    // Forward the streaming tokens to the outer stream
+                    // and concat the text together
+                    stream.text = format!("{}{}", stream.text, text.clone());
+                    Poll::Ready(Some(Ok(AssistantContent::text(text))))
+                }
+                RawStreamingChoice::ToolCall(id, name, args) => {
+                    // Keep track of each tool call to aggregate the final message later
+                    // and pass it to the outer stream
+                    stream.tool_calls.push(ToolCall {
+                        id: id.clone(),
+                        function: ToolFunction {
+                            name: name.clone(),
+                            arguments: args.clone(),
+                        },
+                    });
+                    Poll::Ready(Some(Ok(AssistantContent::tool_call(id, name, args))))
+                }
+                RawStreamingChoice::FinalResponse(response) => {
+                    // Set the final response field and return the next item in the stream
+                    stream.response = Some(response);
+
+                    stream.poll_next_unpin(cx)
+                }
+            },
+        }
+    }
+}
 
 /// Trait for high-level streaming prompt interface
-pub trait StreamingPrompt: Send + Sync {
+pub trait StreamingPrompt<R: Clone + Unpin>: Send + Sync {
     /// Stream a simple prompt to the model
     fn stream_prompt(
         &self,
         prompt: &str,
-    ) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
+    ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
 }
 
 /// Trait for high-level streaming chat interface
-pub trait StreamingChat: Send + Sync {
+pub trait StreamingChat<R: Clone + Unpin>: Send + Sync {
     /// Stream a chat with history to the model
     fn stream_chat(
         &self,
         prompt: &str,
         chat_history: Vec<Message>,
-    ) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
+    ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
 }
 
 /// Trait for low-level streaming completion interface
@@ -78,29 +169,35 @@ pub trait StreamingCompletion<M: StreamingCompletionModel> {
 
 /// Trait defining a streaming completion model
 pub trait StreamingCompletionModel: CompletionModel {
+    type StreamingResponse: Clone + Unpin;
     /// Stream a completion response for the given request
     fn stream(
         &self,
         request: CompletionRequest,
-    ) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
+    ) -> impl Future<
+        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
+    >;
 }
 
 /// helper function to stream a completion request to stdout
 pub async fn stream_to_stdout<M: StreamingCompletionModel>(
     agent: Agent<M>,
-    stream: &mut StreamingResult,
+    stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
 ) -> Result<(), std::io::Error> {
     print!("Response: ");
     while let Some(chunk) = stream.next().await {
         match chunk {
-            Ok(StreamingChoice::Message(text)) => {
-                print!("{}", text);
+            Ok(AssistantContent::Text(text)) => {
+                print!("{}", text.text);
                 std::io::Write::flush(&mut std::io::stdout())?;
             }
-            Ok(StreamingChoice::ToolCall(name, _, params)) => {
+            Ok(AssistantContent::ToolCall(tool_call)) => {
                 let res = agent
                     .tools
-                    .call(&name, params.to_string())
+                    .call(
+                        &tool_call.function.name,
+                        tool_call.function.arguments.to_string(),
+                    )
                     .await
                     .map_err(|e| std::io::Error::other(e.to_string()))?;
                 println!("\nResult: {}", res);
@@ -111,6 +208,7 @@ pub async fn stream_to_stdout<M: StreamingCompletionModel>(
             }
         }
     }
+    println!(); // New line after streaming completes
 
     Ok(())
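For consumers that don't want `stream_to_stdout`, the same contract applies: drive the `Stream` impl to completion, then read the aggregated fields. A sketch with a hypothetical `drain` helper (not part of this diff):

```rust
use futures::StreamExt;
use rig::completion::CompletionResponse;
use rig::message::AssistantContent;
use rig::streaming::StreamingCompletionResponse;

/// Hypothetical helper: drain a streaming response, then hand back the
/// aggregated message plus the provider's final payload.
async fn drain<R: Clone + Unpin>(
    mut stream: StreamingCompletionResponse<R>,
) -> CompletionResponse<Option<R>> {
    while let Some(chunk) = stream.next().await {
        match chunk {
            Ok(AssistantContent::Text(text)) => print!("{}", text.text),
            Ok(AssistantContent::ToolCall(tc)) => println!("\n[tool call] {}", tc.function.name),
            Err(e) => eprintln!("\nstream error: {e}"),
        }
    }

    // By now `poll_next` has populated `choice` (and `response`, if the
    // provider yielded a FinalResponse), so the From impl above can package
    // both as a regular CompletionResponse.
    stream.into()
}
```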