commit d3f6857019
yavens 2025-04-18 10:12:21 +07:00 committed by GitHub
37 changed files with 1437 additions and 317 deletions

View File

@@ -19,5 +19,11 @@ async fn main() -> Result<(), anyhow::Error> {
     stream_to_stdout(agent, &mut stream).await?;
 
+    if let Some(response) = stream.response {
+        println!("Usage: {:?} tokens", response.usage.output_tokens);
+    };
+
+    println!("Message: {:?}", stream.choice);
+
     Ok(())
 }
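Every example touched by this commit consumes the stream the same way: drain it first, then read the provider's final payload from `stream.response` and the aggregated message from `stream.choice`. A minimal end-to-end sketch against the Anthropic provider (model constant and prompt are illustrative, not taken from the commit):

use rig::providers::anthropic;
use rig::streaming::{stream_to_stdout, StreamingPrompt};

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    let agent = anthropic::Client::from_env()
        .agent(anthropic::CLAUDE_3_5_SONNET) // illustrative model constant
        .build();

    let mut stream = agent.stream_prompt("Say hello.").await?;
    stream_to_stdout(agent, &mut stream).await?;

    // `response` is only populated after the stream has been fully drained.
    if let Some(response) = stream.response {
        println!("Usage: {:?} tokens", response.usage.output_tokens);
    }
    println!("Message: {:?}", stream.choice);

    Ok(())
}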

View File

@@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> {
     println!("Calculate 2 - 5");
     let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
     stream_to_stdout(calculator_agent, &mut stream).await?;
+
+    if let Some(response) = stream.response {
+        println!("Usage: {:?} tokens", response.usage.output_tokens);
+    };
+
+    println!("Message: {:?}", stream.choice);
+
     Ok(())
 }

View File

@@ -0,0 +1,27 @@
use rig::providers::cohere;
use rig::streaming::{stream_to_stdout, StreamingPrompt};

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    // Create streaming agent with a single context prompt
    let agent = cohere::Client::from_env()
        .agent(cohere::COMMAND)
        .preamble("Be precise and concise.")
        .temperature(0.5)
        .build();

    // Stream the response and print chunks as they arrive
    let mut stream = agent
        .stream_prompt("When and where and what type is the next solar eclipse?")
        .await?;

    stream_to_stdout(agent, &mut stream).await?;

    if let Some(response) = stream.response {
        println!("Usage: {:?} tokens", response.usage);
    };

    println!("Message: {:?}", stream.choice);

    Ok(())
}

View File

@@ -0,0 +1,118 @@
use anyhow::Result;
use rig::streaming::stream_to_stdout;
use rig::{completion::ToolDefinition, providers, streaming::StreamingPrompt, tool::Tool};
use serde::{Deserialize, Serialize};
use serde_json::json;

#[derive(Deserialize)]
struct OperationArgs {
    x: i32,
    y: i32,
}

#[derive(Debug, thiserror::Error)]
#[error("Math error")]
struct MathError;

#[derive(Deserialize, Serialize)]
struct Adder;

impl Tool for Adder {
    const NAME: &'static str = "add";

    type Error = MathError;
    type Args = OperationArgs;
    type Output = i32;

    async fn definition(&self, _prompt: String) -> ToolDefinition {
        ToolDefinition {
            name: "add".to_string(),
            description: "Add x and y together".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The first number to add"
                    },
                    "y": {
                        "type": "number",
                        "description": "The second number to add"
                    }
                },
                "required": ["x", "y"]
            }),
        }
    }

    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        let result = args.x + args.y;
        Ok(result)
    }
}

#[derive(Deserialize, Serialize)]
struct Subtract;

impl Tool for Subtract {
    const NAME: &'static str = "subtract";

    type Error = MathError;
    type Args = OperationArgs;
    type Output = i32;

    async fn definition(&self, _prompt: String) -> ToolDefinition {
        serde_json::from_value(json!({
            "name": "subtract",
            "description": "Subtract y from x (i.e.: x - y)",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The number to subtract from"
                    },
                    "y": {
                        "type": "number",
                        "description": "The number to subtract"
                    }
                },
                "required": ["x", "y"]
            }
        }))
        .expect("Tool Definition")
    }

    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        let result = args.x - args.y;
        Ok(result)
    }
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    tracing_subscriber::fmt().init();

    // Create agent with a single context prompt and two tools
    let calculator_agent = providers::cohere::Client::from_env()
        .agent(providers::cohere::COMMAND_R)
        .preamble(
            "You are a calculator here to help the user perform arithmetic
            operations. Use the tools provided to answer the user's question.
            make your answer long, so we can test the streaming functionality,
            like 20 words",
        )
        .max_tokens(1024)
        .tool(Adder)
        .tool(Subtract)
        .build();

    println!("Calculate 2 - 5");

    let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
    stream_to_stdout(calculator_agent, &mut stream).await?;

    if let Some(response) = stream.response {
        println!("Usage: {:?} tokens", response.usage);
    };

    println!("Message: {:?}", stream.choice);

    Ok(())
}
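Because `Tool::call` is independent of any provider, the arithmetic tools above can be sanity-checked without a model or API key; a hypothetical unit test (not part of the commit) could live alongside the example:

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn subtract_computes_x_minus_y() {
        // Direct invocation of the Tool trait; no agent involved.
        let result = Subtract.call(OperationArgs { x: 2, y: 5 }).await.unwrap();
        assert_eq!(result, -3);
    }
}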

View File

@@ -19,5 +19,13 @@ async fn main() -> Result<(), anyhow::Error> {
     stream_to_stdout(agent, &mut stream).await?;
 
+    if let Some(response) = stream.response {
+        println!(
+            "Usage: {:?} tokens",
+            response.usage_metadata.total_token_count
+        );
+    };
+
+    println!("Message: {:?}", stream.choice);
+
     Ok(())
 }

View File

@@ -107,5 +107,15 @@ async fn main() -> Result<(), anyhow::Error> {
     println!("Calculate 2 - 5");
     let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
     stream_to_stdout(calculator_agent, &mut stream).await?;
+
+    if let Some(response) = stream.response {
+        println!(
+            "Usage: {:?} tokens",
+            response.usage_metadata.total_token_count
+        );
+    };
+
+    println!("Message: {:?}", stream.choice);
+
     Ok(())
 }

View File

@@ -17,5 +17,10 @@ async fn main() -> Result<(), anyhow::Error> {
     stream_to_stdout(agent, &mut stream).await?;
 
+    if let Some(response) = stream.response {
+        println!("Usage: {:?} tokens", response.eval_count);
+    };
+
+    println!("Message: {:?}", stream.choice);
     Ok(())
 }

View File

@@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> {
     println!("Calculate 2 - 5");
     let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
     stream_to_stdout(calculator_agent, &mut stream).await?;
+
+    if let Some(response) = stream.response {
+        println!("Usage: {:?} tokens", response.eval_count);
+    };
+
+    println!("Message: {:?}", stream.choice);
+
     Ok(())
 }

View File

@@ -17,5 +17,11 @@ async fn main() -> Result<(), anyhow::Error> {
     stream_to_stdout(agent, &mut stream).await?;
 
+    if let Some(response) = stream.response {
+        println!("Usage: {:?}", response.usage)
+    };
+
+    println!("Message: {:?}", stream.choice);
+
     Ok(())
 }

View File

@@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> {
     println!("Calculate 2 - 5");
     let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
     stream_to_stdout(calculator_agent, &mut stream).await?;
+
+    if let Some(response) = stream.response {
+        println!("Usage: {:?}", response.usage)
+    };
+
+    println!("Message: {:?}", stream.choice);
+
     Ok(())
 }

View File

@@ -0,0 +1,118 @@
use anyhow::Result;
use rig::streaming::stream_to_stdout;
use rig::{completion::ToolDefinition, providers, streaming::StreamingPrompt, tool::Tool};
use serde::{Deserialize, Serialize};
use serde_json::json;

#[derive(Deserialize)]
struct OperationArgs {
    x: i32,
    y: i32,
}

#[derive(Debug, thiserror::Error)]
#[error("Math error")]
struct MathError;

#[derive(Deserialize, Serialize)]
struct Adder;

impl Tool for Adder {
    const NAME: &'static str = "add";

    type Error = MathError;
    type Args = OperationArgs;
    type Output = i32;

    async fn definition(&self, _prompt: String) -> ToolDefinition {
        ToolDefinition {
            name: "add".to_string(),
            description: "Add x and y together".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The first number to add"
                    },
                    "y": {
                        "type": "number",
                        "description": "The second number to add"
                    }
                },
                "required": ["x", "y"]
            }),
        }
    }

    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        let result = args.x + args.y;
        Ok(result)
    }
}

#[derive(Deserialize, Serialize)]
struct Subtract;

impl Tool for Subtract {
    const NAME: &'static str = "subtract";

    type Error = MathError;
    type Args = OperationArgs;
    type Output = i32;

    async fn definition(&self, _prompt: String) -> ToolDefinition {
        serde_json::from_value(json!({
            "name": "subtract",
            "description": "Subtract y from x (i.e.: x - y)",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The number to subtract from"
                    },
                    "y": {
                        "type": "number",
                        "description": "The number to subtract"
                    }
                },
                "required": ["x", "y"]
            }
        }))
        .expect("Tool Definition")
    }

    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        let result = args.x - args.y;
        Ok(result)
    }
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    tracing_subscriber::fmt().init();

    // Create agent with a single context prompt and two tools
    let calculator_agent = providers::openrouter::Client::from_env()
        .agent(providers::openrouter::GEMINI_FLASH_2_0)
        .preamble(
            "You are a calculator here to help the user perform arithmetic
            operations. Use the tools provided to answer the user's question.
            make your answer long, so we can test the streaming functionality,
            like 20 words",
        )
        .max_tokens(1024)
        .tool(Adder)
        .tool(Subtract)
        .build();

    println!("Calculate 2 - 5");

    let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
    stream_to_stdout(calculator_agent, &mut stream).await?;

    if let Some(response) = stream.response {
        println!("Usage: {:?}", response.usage)
    };

    println!("Message: {:?}", stream.choice);

    Ok(())
}

View File

@@ -110,23 +110,20 @@ use std::collections::HashMap;
 
 use futures::{stream, StreamExt, TryStreamExt};
 
+use crate::streaming::StreamingCompletionResponse;
+#[cfg(feature = "mcp")]
+use crate::tool::McpTool;
 use crate::{
     completion::{
         Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
         Message, Prompt, PromptError,
     },
     message::AssistantContent,
-    streaming::{
-        StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
-        StreamingResult,
-    },
+    streaming::{StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt},
     tool::{Tool, ToolSet},
     vector_store::{VectorStoreError, VectorStoreIndexDyn},
 };
 
-#[cfg(feature = "mcp")]
-use crate::tool::McpTool;
-
 /// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
 /// (i.e.: system prompt) and a static set of context documents and tools.
 /// All context documents and tools are always provided to the agent when prompted.
@@ -500,18 +497,21 @@ impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
     }
 }
 
-impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
-    async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
+impl<M: StreamingCompletionModel> StreamingPrompt<M::StreamingResponse> for Agent<M> {
+    async fn stream_prompt(
+        &self,
+        prompt: &str,
+    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
         self.stream_chat(prompt, vec![]).await
     }
 }
 
-impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
+impl<M: StreamingCompletionModel> StreamingChat<M::StreamingResponse> for Agent<M> {
     async fn stream_chat(
         &self,
         prompt: &str,
         chat_history: Vec<Message>,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
         self.stream_completion(prompt, chat_history)
             .await?
             .stream()
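The shape of the change to the traits, reconstructed from the call sites above (a sketch, not the crate's actual definition): both streaming traits gain a type parameter for the provider's final response, which flows into the returned wrapper.

// Sketch only. R is the provider-specific final response type
// (e.g. Anthropic's StreamingCompletionResponse); callers read it from the
// wrapper's `response` field once the stream is exhausted.
trait StreamingPrompt<R: Clone + Unpin> {
    async fn stream_prompt(
        &self,
        prompt: &str,
    ) -> Result<StreamingCompletionResponse<R>, CompletionError>;
}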

View File

@@ -67,7 +67,7 @@ use std::collections::HashMap;
 use serde::{Deserialize, Serialize};
 use thiserror::Error;
 
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::OneOrMany;
 use crate::{
     json_utils,
@@ -467,7 +467,9 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
 
 impl<M: StreamingCompletionModel> CompletionRequestBuilder<M> {
     /// Stream the completion request
-    pub async fn stream(self) -> Result<StreamingResult, CompletionError> {
+    pub async fn stream(
+        self,
+    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
         let model = self.model.clone();
         model.stream(self.build()).await
     }

View File

@@ -8,7 +8,8 @@ use super::decoders::sse::from_response as sse_from_response;
 use crate::completion::{CompletionError, CompletionRequest};
 use crate::json_utils::merge_inplace;
 use crate::message::MessageError;
-use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
+use crate::streaming;
+use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult};
 
 #[derive(Debug, Deserialize)]
 #[serde(tag = "type", rename_all = "snake_case")]
@@ -61,7 +62,7 @@ pub struct MessageDelta {
     pub stop_sequence: Option<String>,
 }
 
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Deserialize, Clone)]
 pub struct PartialUsage {
     pub output_tokens: usize,
     #[serde(default)]
@@ -75,11 +76,18 @@ struct ToolCallState {
     input_json: String,
 }
 
+#[derive(Clone)]
+pub struct StreamingCompletionResponse {
+    pub usage: PartialUsage,
+}
+
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
+    {
         let max_tokens = if let Some(tokens) = completion_request.max_tokens {
             tokens
         } else if let Some(tokens) = self.default_max_tokens {
@@ -155,9 +163,10 @@ impl StreamingCompletionModel for CompletionModel {
         // Use our SSE decoder to directly handle Server-Sent Events format
         let sse_stream = sse_from_response(response);
 
-        Ok(Box::pin(stream! {
+        let stream: StreamingResult<Self::StreamingResponse> = Box::pin(stream! {
             let mut current_tool_call: Option<ToolCallState> = None;
             let mut sse_stream = Box::pin(sse_stream);
+            let mut input_tokens = 0;
 
             while let Some(sse_result) = sse_stream.next().await {
                 match sse_result {
@@ -165,6 +174,24 @@ impl StreamingCompletionModel for CompletionModel {
                         // Parse the SSE data as a StreamingEvent
                         match serde_json::from_str::<StreamingEvent>(&sse.data) {
                             Ok(event) => {
+                                match &event {
+                                    StreamingEvent::MessageStart { message } => {
+                                        input_tokens = message.usage.input_tokens;
+                                    },
+                                    StreamingEvent::MessageDelta { delta, usage } => {
+                                        if delta.stop_reason.is_some() {
+                                            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
+                                                usage: PartialUsage {
+                                                    output_tokens: usage.output_tokens,
+                                                    input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
+                                                }
+                                            }))
+                                        }
+                                    }
+                                    _ => {}
+                                }
+
                                 if let Some(result) = handle_event(&event, &mut current_tool_call) {
                                     yield result;
                                 }
@@ -184,19 +211,21 @@ impl StreamingCompletionModel for CompletionModel {
                     }
                 }
             }
-        }))
+        });
+
+        Ok(streaming::StreamingCompletionResponse::new(stream))
     }
 }
 
 fn handle_event(
     event: &StreamingEvent,
     current_tool_call: &mut Option<ToolCallState>,
-) -> Option<Result<StreamingChoice, CompletionError>> {
+) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
     match event {
         StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
             ContentDelta::TextDelta { text } => {
                 if current_tool_call.is_none() {
-                    return Some(Ok(StreamingChoice::Message(text.clone())));
+                    return Some(Ok(RawStreamingChoice::Message(text.clone())));
                 }
                 None
             }
@@ -227,7 +256,7 @@ fn handle_event(
                     &tool_call.input_json
                 };
                 match serde_json::from_str(json_str) {
-                    Ok(json_value) => Some(Ok(StreamingChoice::ToolCall(
+                    Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall(
                         tool_call.name,
                         tool_call.id,
                         json_value,
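Providers now yield `RawStreamingChoice` values internally, and `streaming::StreamingCompletionResponse::new(stream)` is the wrapper that folds the `FinalResponse` variant into the `stream.response` field the examples read. An assumed shape of the enum, reconstructed from its usages in this diff:

// For orientation only; the real definition lives in rig's `streaming` module.
pub enum RawStreamingChoice<R: Clone> {
    /// A streamed chunk of assistant text.
    Message(String),
    /// A completed tool call: two identifier strings (name/id ordering varies
    /// between providers in this diff) plus the parsed JSON arguments.
    ToolCall(String, String, serde_json::Value),
    /// The provider-specific final payload, emitted once at end of stream.
    FinalResponse(R),
}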

View File

@@ -12,7 +12,7 @@
 use super::openai::{send_compatible_streaming_request, TranscriptionResponse};
 
 use crate::json_utils::merge;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -570,10 +570,17 @@ impl completion::CompletionModel for CompletionModel {
 // -----------------------------------------------------
 // Azure OpenAI Streaming API
 // -----------------------------------------------------
 impl StreamingCompletionModel for CompletionModel {
-    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
+    type StreamingResponse = openai::StreamingCompletionResponse;
+    async fn stream(
+        &self,
+        request: CompletionRequest,
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(request)?;
 
-        request = merge(request, json!({"stream": true}));
+        request = merge(
+            request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );
 
         let builder = self
             .client

View File

@@ -6,8 +6,9 @@ use crate::{
 };
 
 use super::client::Client;
+use crate::completion::CompletionRequest;
 use serde::{Deserialize, Serialize};
-use serde_json::json;
+use serde_json::{json, Value};
 
 #[derive(Debug, Deserialize)]
 pub struct CompletionResponse {
@@ -419,7 +420,7 @@ impl TryFrom<Message> for message::Message {
 
 #[derive(Clone)]
 pub struct CompletionModel {
-    client: Client,
+    pub(crate) client: Client,
     pub model: String,
 }
@@ -430,16 +431,11 @@ impl CompletionModel {
             model: model.to_string(),
         }
     }
-}
 
-impl completion::CompletionModel for CompletionModel {
-    type Response = CompletionResponse;
-
-    #[cfg_attr(feature = "worker", worker::send)]
-    async fn completion(
+    pub(crate) fn create_completion_request(
         &self,
-        completion_request: completion::CompletionRequest,
-    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
+        completion_request: CompletionRequest,
+    ) -> Result<Value, CompletionError> {
         let prompt = completion_request.prompt_with_context();
 
         let mut messages: Vec<message::Message> =
@@ -468,23 +464,29 @@ impl completion::CompletionModel for CompletionModel {
             "tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
         });
 
+        if let Some(ref params) = completion_request.additional_params {
+            Ok(json_utils::merge(request.clone(), params.clone()))
+        } else {
+            Ok(request)
+        }
+    }
+}
+
+impl completion::CompletionModel for CompletionModel {
+    type Response = CompletionResponse;
+
+    #[cfg_attr(feature = "worker", worker::send)]
+    async fn completion(
+        &self,
+        completion_request: completion::CompletionRequest,
+    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
+        let request = self.create_completion_request(completion_request)?;
+
         tracing::debug!(
             "Cohere request: {}",
             serde_json::to_string_pretty(&request)?
         );
 
-        let response = self
-            .client
-            .post("/v2/chat")
-            .json(
-                &if let Some(ref params) = completion_request.additional_params {
-                    json_utils::merge(request.clone(), params.clone())
-                } else {
-                    request.clone()
-                },
-            )
-            .send()
-            .await?;
+        let response = self.client.post("/v2/chat").json(&request).send().await?;
 
         if response.status().is_success() {
             let text_response = response.text().await?;
View File

@@ -12,6 +12,7 @@
 pub mod client;
 pub mod completion;
 pub mod embeddings;
+pub mod streaming;
 
 pub use client::Client;
 pub use client::{ApiErrorResponse, ApiResponse};
@@ -23,7 +24,7 @@ pub use embeddings::EmbeddingModel;
 // ================================================================
 /// `command-r-plus` completion model
-pub const COMMAND_R_PLUS: &str = "comman-r-plus";
+pub const COMMAND_R_PLUS: &str = "command-r-plus";
 /// `command-r` completion model
 pub const COMMAND_R: &str = "command-r";
 /// `command` completion model

View File

@@ -0,0 +1,188 @@
use crate::completion::{CompletionError, CompletionRequest};
use crate::providers::cohere::completion::Usage;
use crate::providers::cohere::CompletionModel;
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
use crate::{json_utils, streaming};
use async_stream::stream;
use futures::StreamExt;
use serde::Deserialize;
use serde_json::json;

#[derive(Debug, Deserialize)]
#[serde(rename_all = "kebab-case", tag = "type")]
enum StreamingEvent {
    MessageStart,
    ContentStart,
    ContentDelta { delta: Option<Delta> },
    ContentEnd,
    ToolPlan,
    ToolCallStart { delta: Option<Delta> },
    ToolCallDelta { delta: Option<Delta> },
    ToolCallEnd,
    MessageEnd { delta: Option<MessageEndDelta> },
}

#[derive(Debug, Deserialize)]
struct MessageContentDelta {
    text: Option<String>,
}

#[derive(Debug, Deserialize)]
struct MessageToolFunctionDelta {
    name: Option<String>,
    arguments: Option<String>,
}

#[derive(Debug, Deserialize)]
struct MessageToolCallDelta {
    id: Option<String>,
    function: Option<MessageToolFunctionDelta>,
}

#[derive(Debug, Deserialize)]
struct MessageDelta {
    content: Option<MessageContentDelta>,
    tool_calls: Option<MessageToolCallDelta>,
}

#[derive(Debug, Deserialize)]
struct Delta {
    message: Option<MessageDelta>,
}

#[derive(Debug, Deserialize)]
struct MessageEndDelta {
    usage: Option<Usage>,
}

#[derive(Clone)]
pub struct StreamingCompletionResponse {
    pub usage: Option<Usage>,
}

impl StreamingCompletionModel for CompletionModel {
    type StreamingResponse = StreamingCompletionResponse;

    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
    {
        let request = self.create_completion_request(request)?;
        let request = json_utils::merge(request, json!({"stream": true}));

        tracing::debug!(
            "Cohere request: {}",
            serde_json::to_string_pretty(&request)?
        );

        let response = self.client.post("/v2/chat").json(&request).send().await?;

        if !response.status().is_success() {
            return Err(CompletionError::ProviderError(format!(
                "{}: {}",
                response.status(),
                response.text().await?
            )));
        }

        let stream = Box::pin(stream! {
            let mut stream = response.bytes_stream();
            let mut current_tool_call: Option<(String, String, String)> = None;

            while let Some(chunk_result) = stream.next().await {
                let chunk = match chunk_result {
                    Ok(c) => c,
                    Err(e) => {
                        yield Err(CompletionError::from(e));
                        break;
                    }
                };

                let text = match String::from_utf8(chunk.to_vec()) {
                    Ok(t) => t,
                    Err(e) => {
                        yield Err(CompletionError::ResponseError(e.to_string()));
                        break;
                    }
                };

                for line in text.lines() {
                    let Some(line) = line.strip_prefix("data: ") else {
                        continue;
                    };

                    let event = {
                        let result = serde_json::from_str::<StreamingEvent>(line);

                        let Ok(event) = result else {
                            continue;
                        };

                        event
                    };

                    match event {
                        StreamingEvent::ContentDelta { delta: Some(delta) } => {
                            let Some(message) = &delta.message else { continue; };
                            let Some(content) = &message.content else { continue; };
                            let Some(text) = &content.text else { continue; };

                            yield Ok(RawStreamingChoice::Message(text.clone()));
                        },
                        StreamingEvent::MessageEnd { delta: Some(delta) } => {
                            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                                usage: delta.usage.clone()
                            }));
                        },
                        StreamingEvent::ToolCallStart { delta: Some(delta) } => {
                            // Skip the delta if there's any missing information,
                            // though this *should* all be present
                            let Some(message) = &delta.message else { continue; };
                            let Some(tool_calls) = &message.tool_calls else { continue; };
                            let Some(id) = tool_calls.id.clone() else { continue; };
                            let Some(function) = &tool_calls.function else { continue; };
                            let Some(name) = function.name.clone() else { continue; };
                            let Some(arguments) = function.arguments.clone() else { continue; };

                            current_tool_call = Some((id, name, arguments));
                        },
                        StreamingEvent::ToolCallDelta { delta: Some(delta) } => {
                            // Skip the delta if there's any missing information,
                            // though this *should* all be present
                            let Some(message) = &delta.message else { continue; };
                            let Some(tool_calls) = &message.tool_calls else { continue; };
                            let Some(function) = &tool_calls.function else { continue; };
                            let Some(arguments) = function.arguments.clone() else { continue; };

                            if let Some(tc) = current_tool_call.clone() {
                                current_tool_call = Some((
                                    tc.0,
                                    tc.1,
                                    format!("{}{}", tc.2, arguments)
                                ));
                            };
                        },
                        StreamingEvent::ToolCallEnd => {
                            let Some(tc) = current_tool_call.clone() else { continue; };

                            let Ok(args) = serde_json::from_str(&tc.2) else { continue; };

                            yield Ok(RawStreamingChoice::ToolCall(
                                tc.0,
                                tc.1,
                                args
                            ));

                            current_tool_call = None;
                        },
                        _ => {}
                    };
                }
            }
        });

        Ok(streaming::StreamingCompletionResponse::new(stream))
    }
}
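The deserializer above keys on the kebab-case `type` tag of Cohere's v2 SSE events. A hypothetical parse check (the JSON payload is hand-written to match the structs in this file, not captured from the API):

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_content_delta_event() {
        let line = r#"{"type":"content-delta","delta":{"message":{"content":{"text":"Hello"}}}}"#;
        let event: StreamingEvent = serde_json::from_str(line).unwrap();
        // This is the event the stream loop turns into RawStreamingChoice::Message("Hello").
        assert!(matches!(event, StreamingEvent::ContentDelta { delta: Some(_) }));
    }
}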

View File

@@ -10,8 +10,9 @@
 //! ```
 use crate::json_utils::merge;
+use crate::providers::openai;
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     completion::{self, CompletionError, CompletionModel, CompletionRequest},
     extractor::ExtractorBuilder,
@@ -463,13 +464,17 @@ impl CompletionModel for DeepSeekCompletionModel {
 }
 
 impl StreamingCompletionModel for DeepSeekCompletionModel {
+    type StreamingResponse = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;
 
-        request = merge(request, json!({"stream": true}));
+        request = merge(
+            request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );
 
         let builder = self.client.post("/v1/chat/completions").json(&request);
         send_compatible_streaming_request(builder).await

View File

@@ -13,7 +13,7 @@
 use super::openai;
 use crate::json_utils::merge;
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -495,10 +495,18 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
-    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
+    type StreamingResponse = openai::StreamingCompletionResponse;
+    async fn stream(
+        &self,
+        request: CompletionRequest,
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(request)?;
 
-        request = merge(request, json!({"stream": true}));
+        request = merge(
+            request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );
 
         let builder = self.client.post("/chat/completions").json(&request);

View File

@@ -609,7 +609,7 @@ pub mod gemini_api_types {
         HarmCategoryCivicIntegrity,
     }
 
-    #[derive(Debug, Deserialize)]
+    #[derive(Debug, Deserialize, Clone, Default)]
     #[serde(rename_all = "camelCase")]
     pub struct UsageMetadata {
         pub prompt_token_count: i32,

View File

@@ -2,12 +2,17 @@ use async_stream::stream;
 use futures::StreamExt;
 use serde::Deserialize;
 
+use super::completion::{create_request_body, gemini_api_types::ContentCandidate, CompletionModel};
 use crate::{
     completion::{CompletionError, CompletionRequest},
-    streaming::{self, StreamingCompletionModel, StreamingResult},
+    streaming::{self, StreamingCompletionModel},
 };
 
-use super::completion::{create_request_body, gemini_api_types::ContentCandidate, CompletionModel};
+#[derive(Debug, Deserialize, Default, Clone)]
+#[serde(rename_all = "camelCase")]
+pub struct PartialUsage {
+    pub total_token_count: i32,
+}
 
 #[derive(Debug, Deserialize)]
 #[serde(rename_all = "camelCase")]
@@ -15,13 +20,21 @@ pub struct StreamGenerateContentResponse {
     /// Candidate responses from the model.
     pub candidates: Vec<ContentCandidate>,
     pub model_version: Option<String>,
+    pub usage_metadata: Option<PartialUsage>,
+}
+
+#[derive(Clone)]
+pub struct StreamingCompletionResponse {
+    pub usage_metadata: PartialUsage,
 }
 
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
+    {
         let request = create_request_body(completion_request)?;
 
         let response = self
@@ -42,7 +55,7 @@ impl StreamingCompletionModel for CompletionModel {
             )));
         }
 
-        Ok(Box::pin(stream! {
+        let stream = Box::pin(stream! {
             let mut stream = response.bytes_stream();
 
             while let Some(chunk_result) = stream.next().await {
@@ -74,13 +87,23 @@ impl StreamingCompletionModel for CompletionModel {
                     match choice.content.parts.first() {
                         super::completion::gemini_api_types::Part::Text(text)
-                            => yield Ok(streaming::StreamingChoice::Message(text)),
+                            => yield Ok(streaming::RawStreamingChoice::Message(text)),
                         super::completion::gemini_api_types::Part::FunctionCall(function_call)
-                            => yield Ok(streaming::StreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
+                            => yield Ok(streaming::RawStreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
                         _ => panic!("Unsupported response type with streaming.")
                     };
 
+                    if choice.finish_reason.is_some() {
+                        yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
+                            usage_metadata: PartialUsage {
+                                total_token_count: data.usage_metadata.unwrap().total_token_count,
+                            }
+                        }))
+                    }
                 }
             }
-        }))
+        });
+
+        Ok(streaming::StreamingCompletionResponse::new(stream))
     }
 }
} }

View File

@@ -10,7 +10,8 @@
 //! ```
 use super::openai::{send_compatible_streaming_request, CompletionResponse, TranscriptionResponse};
 
 use crate::json_utils::merge;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::providers::openai;
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -363,10 +364,17 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
-    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
+    type StreamingResponse = openai::StreamingCompletionResponse;
+    async fn stream(
+        &self,
+        request: CompletionRequest,
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(request)?;
 
-        request = merge(request, json!({"stream": true}));
+        request = merge(
+            request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );
 
         let builder = self.client.post("/chat/completions").json(&request);

View File

@@ -1,9 +1,9 @@
 use super::completion::CompletionModel;
 use crate::completion::{CompletionError, CompletionRequest};
-use crate::json_utils;
 use crate::json_utils::merge_inplace;
-use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::providers::openai::{send_compatible_streaming_request, StreamingCompletionResponse};
+use crate::streaming::StreamingCompletionModel;
+use crate::{json_utils, streaming};
 use serde::{Deserialize, Serialize};
 use serde_json::{json, Value};
 use std::convert::Infallible;
@@ -55,14 +55,19 @@ struct CompletionChunk {
 }
 
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
+    {
         let mut request = self.create_request_body(&completion_request)?;
 
         // Enable streaming
-        merge_inplace(&mut request, json!({"stream": true}));
+        merge_inplace(
+            &mut request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );
 
         if let Some(ref params) = completion_request.additional_params {
             merge_inplace(&mut request, params.clone());

View File

@@ -12,7 +12,7 @@
 use super::openai::{send_compatible_streaming_request, AssistantContent};
 
 use crate::json_utils::merge_inplace;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -390,13 +390,17 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;
 
-        merge_inplace(&mut request, json!({"stream": true}));
+        merge_inplace(
+            &mut request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );
 
         let builder = self.client.post("/chat/completions").json(&request);
@@ -526,8 +530,10 @@ mod image_generation {
 // ======================================
 // Hyperbolic Audio Generation API
 // ======================================
+use crate::providers::openai;
+
 #[cfg(feature = "audio")]
 pub use audio_generation::*;
 #[cfg(feature = "audio")]
 mod audio_generation {
     use super::{ApiResponse, Client};

View File

@@ -8,8 +8,9 @@
 //!
 //! ```
 use crate::json_utils::merge;
+use crate::providers::openai;
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -347,10 +348,11 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;
 
         request = merge(request, json!({"stream": true}));

View File

@@ -11,7 +11,7 @@
 use crate::json_utils::merge;
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -228,10 +228,18 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
-    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
+    type StreamingResponse = openai::StreamingCompletionResponse;
+    async fn stream(
+        &self,
+        request: CompletionRequest,
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(request)?;
 
-        request = merge(request, json!({"stream": true}));
+        request = merge(
+            request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );
 
         let builder = self.client.post("/chat/completions").json(&request);

View File

@@ -39,7 +39,7 @@
 //! let extractor = client.extractor::<serde_json::Value>("llama3.2");
 //! ```
 use crate::json_utils::merge_inplace;
-use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
+use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -47,7 +47,7 @@ use crate::{
     extractor::ExtractorBuilder,
     json_utils, message,
     message::{ImageDetail, Text},
-    Embed, OneOrMany,
+    streaming, Embed, OneOrMany,
 };
 use async_stream::stream;
 use futures::StreamExt;
@@ -405,8 +405,25 @@ impl completion::CompletionModel for CompletionModel {
     }
 }
 
+#[derive(Clone)]
+pub struct StreamingCompletionResponse {
+    pub done_reason: Option<String>,
+    pub total_duration: Option<u64>,
+    pub load_duration: Option<u64>,
+    pub prompt_eval_count: Option<u64>,
+    pub prompt_eval_duration: Option<u64>,
+    pub eval_count: Option<u64>,
+    pub eval_duration: Option<u64>,
+}
+
 impl StreamingCompletionModel for CompletionModel {
-    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
+    type StreamingResponse = StreamingCompletionResponse;
+    async fn stream(
+        &self,
+        request: CompletionRequest,
+    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
+    {
         let mut request_payload = self.create_completion_request(request)?;
         merge_inplace(&mut request_payload, json!({"stream": true}));
@@ -426,7 +443,7 @@ impl StreamingCompletionModel for CompletionModel {
             return Err(CompletionError::ProviderError(err_text));
         }
 
-        Ok(Box::pin(stream! {
+        let stream = Box::pin(stream! {
            let mut stream = response.bytes_stream();
            while let Some(chunk_result) = stream.next().await {
                let chunk = match chunk_result {
@@ -456,22 +473,36 @@ impl StreamingCompletionModel for CompletionModel {
                match response.message {
                    Message::Assistant{ content, tool_calls, .. } => {
                        if !content.is_empty() {
-                           yield Ok(StreamingChoice::Message(content))
+                           yield Ok(RawStreamingChoice::Message(content))
                        }
 
                        for tool_call in tool_calls.iter() {
                            let function = tool_call.function.clone();
-                           yield Ok(StreamingChoice::ToolCall(function.name, "".to_string(), function.arguments));
+                           yield Ok(RawStreamingChoice::ToolCall(function.name, "".to_string(), function.arguments));
                        }
                    }
                    _ => {
                        continue;
                    }
                }
+
+               if response.done {
+                   yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
+                       total_duration: response.total_duration,
+                       load_duration: response.load_duration,
+                       prompt_eval_count: response.prompt_eval_count,
+                       prompt_eval_duration: response.prompt_eval_duration,
+                       eval_count: response.eval_count,
+                       eval_duration: response.eval_duration,
+                       done_reason: response.done_reason,
+                   }));
+               }
            }
        }
-        }))
+        });
+
+        Ok(streaming::StreamingCompletionResponse::new(stream))
     }
 }
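Context for the `response.done` branch: Ollama's /api/chat stream ends with a single `done: true` chunk carrying the timing and token counters that `FinalResponse` copies out. A representative final chunk (hand-written illustration following Ollama's documented field names, not captured output):

// Illustrative only.
let final_chunk = serde_json::json!({
    "model": "llama3.2",
    "message": { "role": "assistant", "content": "" },
    "done": true,
    "done_reason": "stop",
    "total_duration": 4883583458u64,
    "load_duration": 1334875u64,
    "prompt_eval_count": 26,
    "prompt_eval_duration": 342546000u64,
    "eval_count": 282,
    "eval_duration": 4535599000u64
});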

View File

@@ -2,14 +2,16 @@ use super::completion::CompletionModel;
 use crate::completion::{CompletionError, CompletionRequest};
 use crate::json_utils;
 use crate::json_utils::merge;
+use crate::providers::openai::Usage;
 use crate::streaming;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
 use async_stream::stream;
 use futures::StreamExt;
 use reqwest::RequestBuilder;
 use serde::{Deserialize, Serialize};
 use serde_json::json;
 use std::collections::HashMap;
+use tracing::debug;
 
 // ================================================================
 // OpenAI Completion Streaming API
@@ -25,10 +27,11 @@ pub struct StreamingFunction {
 
 #[derive(Debug, Serialize, Deserialize, Clone)]
 pub struct StreamingToolCall {
     pub index: usize,
+    pub id: Option<String>,
     pub function: StreamingFunction,
 }
 
-#[derive(Deserialize)]
+#[derive(Deserialize, Debug)]
 struct StreamingDelta {
     #[serde(default)]
     content: Option<String>,
@@ -36,23 +39,34 @@ struct StreamingDelta {
     tool_calls: Vec<StreamingToolCall>,
 }
 
-#[derive(Deserialize)]
+#[derive(Deserialize, Debug)]
 struct StreamingChoice {
     delta: StreamingDelta,
 }
 
-#[derive(Deserialize)]
-struct StreamingCompletionResponse {
+#[derive(Deserialize, Debug)]
+struct StreamingCompletionChunk {
     choices: Vec<StreamingChoice>,
+    usage: Option<Usage>,
+}
+
+#[derive(Clone)]
+pub struct StreamingCompletionResponse {
+    pub usage: Usage,
 }
 
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
+    {
         let mut request = self.create_completion_request(completion_request)?;
 
-        request = merge(request, json!({"stream": true}));
+        request = merge(
+            request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );
 
         let builder = self.client.post("/chat/completions").json(&request);
         send_compatible_streaming_request(builder).await
@@ -61,7 +75,7 @@ impl StreamingCompletionModel for CompletionModel {
 
 pub async fn send_compatible_streaming_request(
     request_builder: RequestBuilder,
-) -> Result<StreamingResult, CompletionError> {
+) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError> {
     let response = request_builder.send().await?;
 
     if !response.status().is_success() {
@@ -73,11 +87,16 @@ pub async fn send_compatible_streaming_request(
     }
 
     // Handle OpenAI Compatible SSE chunks
-    Ok(Box::pin(stream! {
+    let inner = Box::pin(stream! {
         let mut stream = response.bytes_stream();
 
+        let mut final_usage = Usage {
+            prompt_tokens: 0,
+            total_tokens: 0
+        };
+
         let mut partial_data = None;
-        let mut calls: HashMap<usize, (String, String)> = HashMap::new();
+        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
 
         while let Some(chunk_result) = stream.next().await {
             let chunk = match chunk_result {
@@ -100,8 +119,6 @@ pub async fn send_compatible_streaming_request(
             for line in text.lines() {
                 let mut line = line.to_string();
 
-
-
                 // If there was a remaining part, concat with current line
                 if partial_data.is_some() {
                     line = format!("{}{}", partial_data.unwrap(), line);
@@ -121,64 +138,85 @@ pub async fn send_compatible_streaming_request(
                     }
                 }
 
-                let data = serde_json::from_str::<StreamingCompletionResponse>(&line);
+                let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
 
                 let Ok(data) = data else {
+                    let err = data.unwrap_err();
+                    debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
                     continue;
                 };
 
-                let choice = data.choices.first().expect("Should have at least one choice");
-                let delta = &choice.delta;
+                if let Some(choice) = data.choices.first() {
+                    let delta = &choice.delta;
 
-                if !delta.tool_calls.is_empty() {
-                    for tool_call in &delta.tool_calls {
-                        let function = tool_call.function.clone();
+                    if !delta.tool_calls.is_empty() {
+                        for tool_call in &delta.tool_calls {
+                            let function = tool_call.function.clone();
 
-                        // Start of tool call
-                        // name: Some(String)
-                        // arguments: None
-                        if function.name.is_some() && function.arguments.is_empty() {
-                            calls.insert(tool_call.index, (function.name.clone().unwrap(), "".to_string()));
-                        }
-                        // Part of tool call
-                        // name: None
-                        // arguments: Some(String)
-                        else if function.name.is_none() && !function.arguments.is_empty() {
-                            let Some((name, arguments)) = calls.get(&tool_call.index) else {
-                                continue;
-                            };
+                            // Start of tool call
+                            // name: Some(String)
+                            // arguments: None
+                            if function.name.is_some() && function.arguments.is_empty() {
+                                let id = tool_call.id.clone().unwrap_or("".to_string());
 
-                            let new_arguments = &tool_call.function.arguments;
-                            let arguments = format!("{}{}", arguments, new_arguments);
+                                calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
+                            }
+                            // Part of tool call
+                            // name: None
+                            // arguments: Some(String)
+                            else if function.name.is_none() && !function.arguments.is_empty() {
+                                let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
+                                    debug!("Partial tool call received but tool call was never started.");
+                                    continue;
+                                };
 
-                            calls.insert(tool_call.index, (name.clone(), arguments));
-                        }
-                        // Entire tool call
-                        else {
-                            let name = function.name.unwrap();
-                            let arguments = function.arguments;
+                                let new_arguments = &tool_call.function.arguments;
+                                let arguments = format!("{}{}", arguments, new_arguments);
 
-                            let Ok(arguments) = serde_json::from_str(&arguments) else {
-                                continue;
-                            };
+                                calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
+                            }
+                            // Entire tool call
+                            else {
+                                let id = tool_call.id.clone().unwrap_or("".to_string());
+                                let name = function.name.expect("function name should be present for complete tool call");
+                                let arguments = function.arguments;
 
-                            yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
+                                let Ok(arguments) = serde_json::from_str(&arguments) else {
+                                    debug!("Couldn't serialize '{}' as a json value", arguments);
+                                    continue;
+                                };
+
+                                yield Ok(streaming::RawStreamingChoice::ToolCall(id, name, arguments))
+                            }
                         }
                     }
-                }
 
-                if let Some(content) = &choice.delta.content {
-                    yield Ok(streaming::StreamingChoice::Message(content.clone()))
+                    if let Some(content) = &choice.delta.content {
+                        yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
+                    }
+                }
+
+                if let Some(usage) = data.usage {
+                    final_usage = usage.clone();
                 }
             }
         }
 
-        for (_, (name, arguments)) in calls {
+        for (_, (id, name, arguments)) in calls {
             let Ok(arguments) = serde_json::from_str(&arguments) else {
                 continue;
             };
 
-            yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
+            yield Ok(RawStreamingChoice::ToolCall(id, name, arguments))
         }
-    }))
+
+        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
+            usage: final_usage.clone()
+        }))
+    });
+
+    Ok(streaming::StreamingCompletionResponse::new(inner))
 }
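A note on the `stream_options` change repeated across providers: OpenAI-compatible APIs only report token usage on streams when `include_usage` is set, and they deliver it as one extra final chunk whose `choices` array is empty and whose `usage` field is populated. That is why `data.choices.first()` became a guarded `if let` and why usage is held in `final_usage` until the stream ends. An illustrative final chunk (hand-written, not captured):

// The empty-choices, usage-bearing chunk that the loop above folds
// into `final_usage` before yielding FinalResponse.
let last_chunk = serde_json::json!({
    "id": "chatcmpl-123",
    "object": "chat.completion.chunk",
    "choices": [],
    "usage": { "prompt_tokens": 17, "completion_tokens": 32, "total_tokens": 49 }
});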

View File

@ -0,0 +1,125 @@
use crate::{agent::AgentBuilder, extractor::ExtractorBuilder};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use super::completion::CompletionModel;
// ================================================================
// Main openrouter Client
// ================================================================
const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
#[derive(Clone)]
pub struct Client {
base_url: String,
http_client: reqwest::Client,
}
impl Client {
/// Create a new OpenRouter client with the given API key.
pub fn new(api_key: &str) -> Self {
Self::from_url(api_key, OPENROUTER_API_BASE_URL)
}
/// Create a new OpenRouter client with the given API key and base API URL.
pub fn from_url(api_key: &str, base_url: &str) -> Self {
Self {
base_url: base_url.to_string(),
http_client: reqwest::Client::builder()
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"Authorization",
format!("Bearer {}", api_key)
.parse()
.expect("Bearer token should parse"),
);
headers
})
.build()
.expect("OpenRouter reqwest client should build"),
}
}
/// Create a new openrouter client from the `openrouter_API_KEY` environment variable.
/// Panics if the environment variable is not set.
pub fn from_env() -> Self {
let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
Self::new(&api_key)
}
pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
self.http_client.post(url)
}
/// Create a completion model with the given name.
///
/// # Example
/// ```
/// use rig::providers::openrouter::{Client, self};
///
/// // Initialize the openrouter client
/// let openrouter = Client::new("your-openrouter-api-key");
///
/// let llama_3_1_8b = openrouter.completion_model(openrouter::LLAMA_3_1_8B);
/// ```
pub fn completion_model(&self, model: &str) -> CompletionModel {
CompletionModel::new(self.clone(), model)
}
/// Create an agent builder with the given completion model.
///
/// # Example
/// ```
/// use rig::providers::openrouter::{Client, self};
///
    /// // Initialize the OpenRouter client
/// let openrouter = Client::new("your-openrouter-api-key");
///
/// let agent = openrouter.agent(openrouter::LLAMA_3_1_8B)
/// .preamble("You are comedian AI with a mission to make people laugh.")
/// .temperature(0.0)
/// .build();
/// ```
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
AgentBuilder::new(self.completion_model(model))
}
/// Create an extractor builder with the given completion model.
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
&self,
model: &str,
) -> ExtractorBuilder<T, CompletionModel> {
ExtractorBuilder::new(self.completion_model(model))
}
}
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
pub(crate) message: String,
}
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
Ok(T),
Err(ApiErrorResponse),
}
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
pub prompt_tokens: usize,
pub completion_tokens: usize,
pub total_tokens: usize,
}
impl std::fmt::Display for Usage {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Prompt tokens: {} Total tokens: {}",
self.prompt_tokens, self.total_tokens
)
}
}
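// Illustrative sketch (not part of this diff): constructing the client and
// printing usage. Assumes OPENROUTER_API_KEY is set; the model id below
// ("meta-llama/llama-3.1-8b-instruct") is a placeholder, not a constant
// defined by this crate.
fn client_usage_sketch() {
    let client = Client::from_env();
    let _model = client.completion_model("meta-llama/llama-3.1-8b-instruct");

    // `Usage` implements Display, so token counts can be printed directly.
    let usage = Usage { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 };
    println!("{usage}"); // Prompt tokens: 10 Total tokens: 15
}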

View File

@ -1,147 +1,16 @@
use serde::Deserialize;

use super::client::{ApiErrorResponse, ApiResponse, Client, Usage};

use crate::{
    completion::{self, CompletionError, CompletionRequest},
    json_utils,
    providers::openai::Message,
    OneOrMany,
};
use serde_json::{json, Value};

use crate::providers::openai::AssistantContent;
// ================================================================
// OpenRouter Completion API
// ================================================================
@ -241,7 +110,7 @@ pub struct Choice {
#[derive(Clone)]
pub struct CompletionModel {
    pub(crate) client: Client,
    /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
    pub model: String,
}
@ -253,16 +122,11 @@ impl CompletionModel {
            model: model.to_string(),
        }
    }

    pub(crate) fn create_completion_request(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<Value, CompletionError> {
        // Add preamble to chat history (if available)
        let mut full_history: Vec<Message> = match &completion_request.preamble {
            Some(preamble) => vec![Message::system(preamble)],
@ -292,16 +156,30 @@ impl completion::CompletionModel for CompletionModel {
            "temperature": completion_request.temperature,
        });

        let request = if let Some(params) = completion_request.additional_params {
            json_utils::merge(request, params)
        } else {
            request
        };

        Ok(request)
    }
}

impl completion::CompletionModel for CompletionModel {
    type Response = CompletionResponse;

    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let request = self.create_completion_request(completion_request)?;

        let response = self
            .client
            .post("/chat/completions")
            .json(&request)
            .send()
            .await?;
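// Illustrative sketch (not part of this diff): how `additional_params` overlay
// the base request body. The "top_k" field below is an assumed example, not a
// value this crate sets; `json_utils::merge` is the crate-internal helper used above.
fn additional_params_sketch() {
    let base = serde_json::json!({ "model": "some-model", "temperature": 0.5 });
    let extra = serde_json::json!({ "top_k": 40 });

    // Keys from `extra` are layered onto `base`, leaving other fields intact.
    let merged = json_utils::merge(base, extra);
    assert_eq!(merged["top_k"], 40);
    assert_eq!(merged["temperature"], 0.5);
}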

View File

@ -0,0 +1,17 @@
//! OpenRouter Inference API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::openrouter;
//!
//! let client = openrouter::Client::new("YOUR_API_KEY");
//!
//! let llama_3_1_8b = client.completion_model(openrouter::LLAMA_3_1_8B);
//! ```
pub mod client;
pub mod completion;
pub mod streaming;
pub use client::*;
pub use completion::*;

View File

@ -0,0 +1,313 @@
use std::collections::HashMap;
use crate::{
json_utils,
message::{ToolCall, ToolFunction},
streaming::{self},
};
use async_stream::stream;
use futures::StreamExt;
use reqwest::RequestBuilder;
use serde_json::{json, Value};
use crate::{
completion::{CompletionError, CompletionRequest},
streaming::StreamingCompletionModel,
};
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
pub id: String,
pub choices: Vec<StreamingChoice>,
pub created: u64,
pub model: String,
pub object: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<ResponseUsage>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamingChoice {
#[serde(skip_serializing_if = "Option::is_none")]
pub finish_reason: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub native_finish_reason: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<Value>,
pub index: usize,
#[serde(skip_serializing_if = "Option::is_none")]
pub message: Option<MessageResponse>,
#[serde(skip_serializing_if = "Option::is_none")]
pub delta: Option<DeltaResponse>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<ErrorResponse>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct MessageResponse {
pub role: String,
pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub refusal: Option<Value>,
#[serde(default)]
pub tool_calls: Vec<OpenRouterToolCall>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct OpenRouterToolFunction {
pub name: Option<String>,
pub arguments: Option<String>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct OpenRouterToolCall {
pub index: usize,
pub id: Option<String>,
pub r#type: Option<String>,
pub function: OpenRouterToolFunction,
}
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct ResponseUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct ErrorResponse {
pub code: i32,
pub message: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<HashMap<String, Value>>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct DeltaResponse {
pub role: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(default)]
pub tool_calls: Vec<OpenRouterToolCall>,
#[serde(skip_serializing_if = "Option::is_none")]
pub native_finish_reason: Option<String>,
}
#[derive(Clone)]
pub struct FinalCompletionResponse {
pub usage: ResponseUsage,
}
impl StreamingCompletionModel for super::CompletionModel {
type StreamingResponse = FinalCompletionResponse;
async fn stream(
&self,
completion_request: CompletionRequest,
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
{
let request = self.create_completion_request(completion_request)?;
let request = json_utils::merge(request, json!({"stream": true}));
let builder = self.client.post("/chat/completions").json(&request);
send_streaming_request(builder).await
}
}
pub async fn send_streaming_request(
request_builder: RequestBuilder,
) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError> {
let response = request_builder.send().await?;
if !response.status().is_success() {
return Err(CompletionError::ProviderError(format!(
"{}: {}",
response.status(),
response.text().await?
)));
}
// Handle OpenAI Compatible SSE chunks
let stream = Box::pin(stream! {
let mut stream = response.bytes_stream();
let mut tool_calls = HashMap::new();
let mut partial_line = String::new();
let mut final_usage = None;
while let Some(chunk_result) = stream.next().await {
let chunk = match chunk_result {
Ok(c) => c,
Err(e) => {
yield Err(CompletionError::from(e));
break;
}
};
let text = match String::from_utf8(chunk.to_vec()) {
Ok(t) => t,
Err(e) => {
yield Err(CompletionError::ResponseError(e.to_string()));
break;
}
};
for line in text.lines() {
let mut line = line.to_string();
// Skip empty lines and processing messages, as well as [DONE] (might be useful though)
if line.trim().is_empty() || line.trim() == ": OPENROUTER PROCESSING" || line.trim() == "data: [DONE]" {
continue;
}
// Handle data: prefix
line = line.strip_prefix("data: ").unwrap_or(&line).to_string();
// If line starts with { but doesn't end with }, it's a partial JSON
if line.starts_with('{') && !line.ends_with('}') {
partial_line = line;
continue;
}
// If we have a partial line and this line ends with }, complete it
if !partial_line.is_empty() {
if line.ends_with('}') {
partial_line.push_str(&line);
line = partial_line;
partial_line = String::new();
} else {
partial_line.push_str(&line);
continue;
}
}
let data = match serde_json::from_str::<StreamingCompletionResponse>(&line) {
Ok(data) => data,
Err(_) => {
continue;
}
};
let choice = data.choices.first().expect("Should have at least one choice");
// TODO this has to handle outputs like this:
// [{"index": 0, "id": "call_DdmO9pD3xa9XTPNJ32zg2hcA", "function": {"arguments": "", "name": "get_weather"}, "type": "function"}]
// [{"index": 0, "id": null, "function": {"arguments": "{\"", "name": null}, "type": null}]
// [{"index": 0, "id": null, "function": {"arguments": "location", "name": null}, "type": null}]
// [{"index": 0, "id": null, "function": {"arguments": "\":\"", "name": null}, "type": null}]
// [{"index": 0, "id": null, "function": {"arguments": "Paris", "name": null}, "type": null}]
// [{"index": 0, "id": null, "function": {"arguments": ",", "name": null}, "type": null}]
// [{"index": 0, "id": null, "function": {"arguments": " France", "name": null}, "type": null}]
// [{"index": 0, "id": null, "function": {"arguments": "\"}", "name": null}, "type": null}]
if let Some(delta) = &choice.delta {
if !delta.tool_calls.is_empty() {
for tool_call in &delta.tool_calls {
let index = tool_call.index;
// Get or create tool call entry
let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
id: String::new(),
function: ToolFunction {
name: String::new(),
arguments: serde_json::Value::Null,
},
});
// Update fields if present
if let Some(id) = &tool_call.id {
if !id.is_empty() {
existing_tool_call.id = id.clone();
}
}
if let Some(name) = &tool_call.function.name {
if !name.is_empty() {
existing_tool_call.function.name = name.clone();
}
}
if let Some(chunk) = &tool_call.function.arguments {
// Convert current arguments to string if needed
let current_args = match &existing_tool_call.function.arguments {
serde_json::Value::Null => String::new(),
serde_json::Value::String(s) => s.clone(),
v => v.to_string(),
};
// Concatenate the new chunk
let combined = format!("{}{}", current_args, chunk);
// Try to parse as JSON if it looks complete
if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
match serde_json::from_str(&combined) {
Ok(parsed) => existing_tool_call.function.arguments = parsed,
Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
}
} else {
existing_tool_call.function.arguments = serde_json::Value::String(combined);
}
}
}
}
if let Some(content) = &delta.content {
if !content.is_empty() {
yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
}
}
if let Some(usage) = data.usage {
final_usage = Some(usage);
}
}
// Handle message format
if let Some(message) = &choice.message {
if !message.tool_calls.is_empty() {
for tool_call in &message.tool_calls {
let name = tool_call.function.name.clone();
let id = tool_call.id.clone();
let arguments = if let Some(args) = &tool_call.function.arguments {
// Try to parse the string as JSON, fallback to string value
match serde_json::from_str(args) {
Ok(v) => v,
Err(_) => serde_json::Value::String(args.to_string()),
}
} else {
serde_json::Value::Null
};
let index = tool_call.index;
tool_calls.insert(index, ToolCall{
id: id.unwrap_or_default(),
function: ToolFunction {
name: name.unwrap_or_default(),
arguments,
},
});
}
}
if !message.content.is_empty() {
yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()))
}
}
}
}
for (_, tool_call) in tool_calls.into_iter() {
            yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call.id, tool_call.function.name, tool_call.function.arguments));
}
yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
usage: final_usage.unwrap_or_default()
}))
});
Ok(streaming::StreamingCompletionResponse::new(stream))
}
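// Illustrative sketch (not part of this diff): the partial-line buffering used
// above. A JSON object split across two SSE chunks only parses once the
// buffered halves are joined; the input lines are assumed example data.
fn partial_line_sketch() {
    let mut partial_line = String::new();
    let mut complete = Vec::new();

    for line in ["data: {\"id\":\"gen-1\",", "\"model\":\"test\"}"] {
        let line = line.strip_prefix("data: ").unwrap_or(line);
        if line.starts_with('{') && !line.ends_with('}') {
            // First half of a split object: buffer it and wait for more
            partial_line = line.to_string();
            continue;
        }
        if !partial_line.is_empty() && line.ends_with('}') {
            // Second half: complete the buffered object
            partial_line.push_str(line);
            complete.push(std::mem::take(&mut partial_line));
        } else {
            complete.push(line.to_string());
        }
    }

    assert!(serde_json::from_str::<serde_json::Value>(&complete[0]).is_ok());
}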

View File

@ -18,8 +18,9 @@ use crate::{
use crate::completion::CompletionRequest;
use crate::json_utils::merge;
use crate::providers::openai;
use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
@ -345,10 +346,11 @@ impl completion::CompletionModel for CompletionModel {
}

impl StreamingCompletionModel for CompletionModel {
    type StreamingResponse = openai::StreamingCompletionResponse;
    async fn stream(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let mut request = self.create_completion_request(completion_request)?;

        request = merge(request, json!({"stream": true}));

View File

@ -1,18 +1,21 @@
use serde_json::json;

use super::completion::CompletionModel;
use crate::providers::openai;
use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::StreamingCompletionResponse;
use crate::{
    completion::{CompletionError, CompletionRequest},
    json_utils::merge,
    streaming::StreamingCompletionModel,
};

impl StreamingCompletionModel for CompletionModel {
    type StreamingResponse = openai::StreamingCompletionResponse;
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let mut request = self.create_completion_request(completion_request)?;

        request = merge(request, json!({"stream_tokens": true}));

View File

@ -1,15 +1,17 @@
use crate::completion::{CompletionError, CompletionRequest};
use crate::json_utils::merge;
use crate::providers::openai;
use crate::providers::openai::send_compatible_streaming_request;
use crate::providers::xai::completion::CompletionModel;
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use serde_json::json;

impl StreamingCompletionModel for CompletionModel {
    type StreamingResponse = openai::StreamingCompletionResponse;
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let mut request = self.create_completion_request(completion_request)?;

        request = merge(request, json!({"stream": true}));

View File

@ -11,59 +11,150 @@
use crate::agent::Agent;
use crate::completion::{
    CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder,
    CompletionResponse, Message,
};
use crate::message::{AssistantContent, ToolCall, ToolFunction};
use crate::OneOrMany;
use futures::{Stream, StreamExt};
use std::boxed::Box;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

/// Enum representing a streaming chunk from the model
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R: Clone> {
    /// A text chunk from a message response
    Message(String),

    /// A tool call response chunk
    ToolCall(String, String, serde_json::Value),

    /// The final response object, must be yielded if you want the
    /// `response` field to be populated on the `StreamingCompletionResponse`
    FinalResponse(R),
}

#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

#[cfg(target_arch = "wasm32")]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;

/// The response from a streaming completion request;
/// message and response are populated at the end of the
/// `inner` stream.
pub struct StreamingCompletionResponse<R: Clone + Unpin> {
    inner: StreamingResult<R>,
    text: String,
    tool_calls: Vec<ToolCall>,
    /// The final aggregated message from the stream
    /// contains all text and tool calls generated
    pub choice: OneOrMany<AssistantContent>,
    /// The final response from the stream, may be `None`
    /// if the provider didn't yield it during the stream
    pub response: Option<R>,
}

impl<R: Clone + Unpin> StreamingCompletionResponse<R> {
    pub fn new(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
        Self {
            inner,
            text: "".to_string(),
            tool_calls: vec![],
            choice: OneOrMany::one(AssistantContent::text("")),
            response: None,
        }
    }
}

impl<R: Clone + Unpin> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>> {
    fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
        CompletionResponse {
            choice: value.choice,
            raw_response: value.response,
        }
    }
}

impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
    type Item = Result<AssistantContent, CompletionError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let stream = self.get_mut();

        match stream.inner.as_mut().poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) => {
                // This is run at the end of the inner stream to collect all tokens into
                // a single unified `Message`.
                let mut choice = vec![];

                stream.tool_calls.iter().for_each(|tc| {
                    choice.push(AssistantContent::ToolCall(tc.clone()));
                });

                // This is required to ensure there's always at least one item in the content
                if choice.is_empty() || !stream.text.is_empty() {
                    choice.insert(0, AssistantContent::text(stream.text.clone()));
                }

                stream.choice = OneOrMany::many(choice)
                    .expect("There should be at least one assistant message");

                Poll::Ready(None)
            }
            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
            Poll::Ready(Some(Ok(choice))) => match choice {
                RawStreamingChoice::Message(text) => {
                    // Forward the streaming tokens to the outer stream
                    // and concat the text together
                    stream.text = format!("{}{}", stream.text, text.clone());
                    Poll::Ready(Some(Ok(AssistantContent::text(text))))
                }
                RawStreamingChoice::ToolCall(id, name, args) => {
                    // Keep track of each tool call to aggregate the final message later
                    // and pass it to the outer stream
                    stream.tool_calls.push(ToolCall {
                        id: id.clone(),
                        function: ToolFunction {
                            name: name.clone(),
                            arguments: args.clone(),
                        },
                    });
                    Poll::Ready(Some(Ok(AssistantContent::tool_call(id, name, args))))
                }
                RawStreamingChoice::FinalResponse(response) => {
                    // Set the final response field and return the next item in the stream
                    stream.response = Some(response);
                    stream.poll_next_unpin(cx)
                }
            },
        }
    }
}

/// Trait for high-level streaming prompt interface
pub trait StreamingPrompt<R: Clone + Unpin>: Send + Sync {
    /// Stream a simple prompt to the model
    fn stream_prompt(
        &self,
        prompt: &str,
    ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
}

/// Trait for high-level streaming chat interface
pub trait StreamingChat<R: Clone + Unpin>: Send + Sync {
    /// Stream a chat with history to the model
    fn stream_chat(
        &self,
        prompt: &str,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
}

/// Trait for low-level streaming completion interface
@ -78,29 +169,35 @@ pub trait StreamingCompletion<M: StreamingCompletionModel> {
/// Trait defining a streaming completion model
pub trait StreamingCompletionModel: CompletionModel {
    type StreamingResponse: Clone + Unpin;
    /// Stream a completion response for the given request
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> impl Future<
        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
    >;
}

/// helper function to stream a completion request to stdout
pub async fn stream_to_stdout<M: StreamingCompletionModel>(
    agent: Agent<M>,
    stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
) -> Result<(), std::io::Error> {
    print!("Response: ");
    while let Some(chunk) = stream.next().await {
        match chunk {
            Ok(AssistantContent::Text(text)) => {
                print!("{}", text.text);
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(AssistantContent::ToolCall(tool_call)) => {
                let res = agent
                    .tools
                    .call(
                        &tool_call.function.name,
                        tool_call.function.arguments.to_string(),
                    )
                    .await
                    .map_err(|e| std::io::Error::other(e.to_string()))?;
                println!("\nResult: {}", res);
@ -111,6 +208,7 @@ pub async fn stream_to_stdout<M: StreamingCompletionModel>(
            }
        }
    }

    println!(); // New line after streaming completes

    Ok(())
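// Illustrative sketch (not part of this diff): consuming the stream manually
// instead of via stream_to_stdout. `M` is any StreamingCompletionModel; the
// stream is assumed to come from a call like `agent.stream_prompt(...)`.
async fn consume_manually<M: StreamingCompletionModel>(
    mut stream: StreamingCompletionResponse<M::StreamingResponse>,
) -> Result<(), CompletionError> {
    while let Some(chunk) = stream.next().await {
        match chunk? {
            AssistantContent::Text(text) => print!("{}", text.text),
            AssistantContent::ToolCall(tc) => println!("tool call: {}", tc.function.name),
        }
    }

    // Once the inner stream is exhausted, the aggregated fields are populated.
    let _aggregated: OneOrMany<AssistantContent> = stream.choice;
    let _final_response: Option<M::StreamingResponse> = stream.response;
    Ok(())
}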