yavens 2025-04-18 10:12:21 +07:00 committed by GitHub
commit d3f6857019
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
37 changed files with 1437 additions and 317 deletions


@@ -19,5 +19,11 @@ async fn main() -> Result<(), anyhow::Error> {
stream_to_stdout(agent, &mut stream).await?;
if let Some(response) = stream.response {
println!("Usage: {:?} tokens", response.usage.output_tokens);
};
println!("Message: {:?}", stream.choice);
Ok(())
}


@@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> {
println!("Calculate 2 - 5");
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
stream_to_stdout(calculator_agent, &mut stream).await?;
if let Some(response) = stream.response {
println!("Usage: {:?} tokens", response.usage.output_tokens);
};
println!("Message: {:?}", stream.choice);
Ok(())
}


@@ -0,0 +1,27 @@
use rig::providers::cohere;
use rig::streaming::{stream_to_stdout, StreamingPrompt};
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Create streaming agent with a single context prompt
let agent = cohere::Client::from_env()
.agent(cohere::COMMAND)
.preamble("Be precise and concise.")
.temperature(0.5)
.build();
// Stream the response and print chunks as they arrive
let mut stream = agent
.stream_prompt("When and where and what type is the next solar eclipse?")
.await?;
stream_to_stdout(agent, &mut stream).await?;
if let Some(response) = stream.response {
println!("Usage: {:?} tokens", response.usage);
};
println!("Message: {:?}", stream.choice);
Ok(())
}


@@ -0,0 +1,118 @@
use anyhow::Result;
use rig::streaming::stream_to_stdout;
use rig::{completion::ToolDefinition, providers, streaming::StreamingPrompt, tool::Tool};
use serde::{Deserialize, Serialize};
use serde_json::json;
#[derive(Deserialize)]
struct OperationArgs {
x: i32,
y: i32,
}
#[derive(Debug, thiserror::Error)]
#[error("Math error")]
struct MathError;
#[derive(Deserialize, Serialize)]
struct Adder;
impl Tool for Adder {
const NAME: &'static str = "add";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
ToolDefinition {
name: "add".to_string(),
description: "Add x and y together".to_string(),
parameters: json!({
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The first number to add"
},
"y": {
"type": "number",
"description": "The second number to add"
}
},
"required": ["x", "y"]
}),
}
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let result = args.x + args.y;
Ok(result)
}
}
#[derive(Deserialize, Serialize)]
struct Subtract;
impl Tool for Subtract {
const NAME: &'static str = "subtract";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
serde_json::from_value(json!({
"name": "subtract",
"description": "Subtract y from x (i.e.: x - y)",
"parameters": {
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The number to subtract from"
},
"y": {
"type": "number",
"description": "The number to subtract"
}
},
"required": ["x", "y"]
}
}))
.expect("Tool Definition")
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let result = args.x - args.y;
Ok(result)
}
}
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
tracing_subscriber::fmt().init();
// Create agent with a single context prompt and two tools
let calculator_agent = providers::cohere::Client::from_env()
.agent(providers::cohere::COMMAND_R)
.preamble(
"You are a calculator here to help the user perform arithmetic
operations. Use the tools provided to answer the user's question.
make your answer long, so we can test the streaming functionality,
like 20 words",
)
.max_tokens(1024)
.tool(Adder)
.tool(Subtract)
.build();
println!("Calculate 2 - 5");
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
stream_to_stdout(calculator_agent, &mut stream).await?;
if let Some(response) = stream.response {
println!("Usage: {:?} tokens", response.usage);
};
println!("Message: {:?}", stream.choice);
Ok(())
}


@@ -19,5 +19,13 @@ async fn main() -> Result<(), anyhow::Error> {
stream_to_stdout(agent, &mut stream).await?;
if let Some(response) = stream.response {
println!(
"Usage: {:?} tokens",
response.usage_metadata.total_token_count
);
};
println!("Message: {:?}", stream.choice);
Ok(())
}


@@ -107,5 +107,15 @@ async fn main() -> Result<(), anyhow::Error> {
println!("Calculate 2 - 5");
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
stream_to_stdout(calculator_agent, &mut stream).await?;
if let Some(response) = stream.response {
println!(
"Usage: {:?} tokens",
response.usage_metadata.total_token_count
);
};
println!("Message: {:?}", stream.choice);
Ok(())
}


@@ -17,5 +17,10 @@ async fn main() -> Result<(), anyhow::Error> {
stream_to_stdout(agent, &mut stream).await?;
if let Some(response) = stream.response {
println!("Usage: {:?} tokens", response.eval_count);
};
println!("Message: {:?}", stream.choice);
Ok(())
}


@@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> {
println!("Calculate 2 - 5");
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
stream_to_stdout(calculator_agent, &mut stream).await?;
if let Some(response) = stream.response {
println!("Usage: {:?} tokens", response.eval_count);
};
println!("Message: {:?}", stream.choice);
Ok(())
}


@@ -17,5 +17,11 @@ async fn main() -> Result<(), anyhow::Error> {
stream_to_stdout(agent, &mut stream).await?;
if let Some(response) = stream.response {
println!("Usage: {:?}", response.usage)
};
println!("Message: {:?}", stream.choice);
Ok(())
}


@@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> {
println!("Calculate 2 - 5");
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
stream_to_stdout(calculator_agent, &mut stream).await?;
if let Some(response) = stream.response {
println!("Usage: {:?}", response.usage)
};
println!("Message: {:?}", stream.choice);
Ok(())
}


@@ -0,0 +1,118 @@
use anyhow::Result;
use rig::streaming::stream_to_stdout;
use rig::{completion::ToolDefinition, providers, streaming::StreamingPrompt, tool::Tool};
use serde::{Deserialize, Serialize};
use serde_json::json;
#[derive(Deserialize)]
struct OperationArgs {
x: i32,
y: i32,
}
#[derive(Debug, thiserror::Error)]
#[error("Math error")]
struct MathError;
#[derive(Deserialize, Serialize)]
struct Adder;
impl Tool for Adder {
const NAME: &'static str = "add";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
ToolDefinition {
name: "add".to_string(),
description: "Add x and y together".to_string(),
parameters: json!({
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The first number to add"
},
"y": {
"type": "number",
"description": "The second number to add"
}
},
"required": ["x", "y"]
}),
}
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let result = args.x + args.y;
Ok(result)
}
}
#[derive(Deserialize, Serialize)]
struct Subtract;
impl Tool for Subtract {
const NAME: &'static str = "subtract";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
serde_json::from_value(json!({
"name": "subtract",
"description": "Subtract y from x (i.e.: x - y)",
"parameters": {
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The number to subtract from"
},
"y": {
"type": "number",
"description": "The number to subtract"
}
},
"required": ["x", "y"]
}
}))
.expect("Tool Definition")
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let result = args.x - args.y;
Ok(result)
}
}
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
tracing_subscriber::fmt().init();
// Create agent with a single context prompt and two tools
let calculator_agent = providers::openrouter::Client::from_env()
.agent(providers::openrouter::GEMINI_FLASH_2_0)
.preamble(
"You are a calculator here to help the user perform arithmetic
operations. Use the tools provided to answer the user's question.
make your answer long, so we can test the streaming functionality,
like 20 words",
)
.max_tokens(1024)
.tool(Adder)
.tool(Subtract)
.build();
println!("Calculate 2 - 5");
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
stream_to_stdout(calculator_agent, &mut stream).await?;
if let Some(response) = stream.response {
println!("Usage: {:?}", response.usage)
};
println!("Message: {:?}", stream.choice);
Ok(())
}


@@ -110,23 +110,20 @@ use std::collections::HashMap;
use futures::{stream, StreamExt, TryStreamExt};
use crate::streaming::StreamingCompletionResponse;
#[cfg(feature = "mcp")]
use crate::tool::McpTool;
use crate::{
completion::{
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
Message, Prompt, PromptError,
},
message::AssistantContent,
streaming::{
StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
StreamingResult,
},
streaming::{StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt},
tool::{Tool, ToolSet},
vector_store::{VectorStoreError, VectorStoreIndexDyn},
};
#[cfg(feature = "mcp")]
use crate::tool::McpTool;
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
/// All context documents and tools are always provided to the agent when prompted.
@@ -500,18 +497,21 @@ impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
}
}
impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
impl<M: StreamingCompletionModel> StreamingPrompt<M::StreamingResponse> for Agent<M> {
async fn stream_prompt(
&self,
prompt: &str,
) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
self.stream_chat(prompt, vec![]).await
}
}
impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
impl<M: StreamingCompletionModel> StreamingChat<M::StreamingResponse> for Agent<M> {
async fn stream_chat(
&self,
prompt: &str,
chat_history: Vec<Message>,
) -> Result<StreamingResult, CompletionError> {
) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
self.stream_completion(prompt, chat_history)
.await?
.stream()

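The reworked traits above change what callers get back: stream_prompt and stream_chat now return a streaming::StreamingCompletionResponse wrapper instead of a bare StreamingResult. A minimal consumer sketch, mirroring the example binaries in this commit (the OpenAI provider and GPT_4O constant are illustrative choices, not part of this diff):

use rig::providers::openai;
use rig::streaming::{stream_to_stdout, StreamingPrompt};

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    // Any provider implementing StreamingCompletionModel works here.
    let agent = openai::Client::from_env()
        .agent(openai::GPT_4O)
        .preamble("Be precise and concise.")
        .build();

    // stream_prompt now returns StreamingCompletionResponse<M::StreamingResponse>.
    let mut stream = agent.stream_prompt("Hello!").await?;
    stream_to_stdout(agent, &mut stream).await?;

    // Once the stream is drained, the provider-specific final response
    // (including usage) is available on the wrapper.
    if let Some(response) = stream.response {
        println!("Usage: {:?}", response.usage);
    }
    println!("Message: {:?}", stream.choice);
    Ok(())
}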

@@ -67,7 +67,7 @@ use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use crate::OneOrMany;
use crate::{
json_utils,
@@ -467,7 +467,9 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
impl<M: StreamingCompletionModel> CompletionRequestBuilder<M> {
/// Stream the completion request
pub async fn stream(self) -> Result<StreamingResult, CompletionError> {
pub async fn stream(
self,
) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
let model = self.model.clone();
model.stream(self.build()).await
}

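Since CompletionRequestBuilder::stream now returns the same wrapper, the two-step path used by Agent::stream_chat above composes with the usual builder methods. A sketch under assumptions (the helper name is hypothetical; stream_completion and temperature are the calls shown or implied above):

use rig::agent::Agent;
use rig::completion::CompletionError;
use rig::streaming::{
    StreamingCompletion, StreamingCompletionModel, StreamingCompletionResponse,
};

// Hypothetical helper: tweak the request via the builder, then stream it.
async fn stream_with_temperature<M: StreamingCompletionModel>(
    agent: &Agent<M>,
    prompt: &str,
) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
    agent
        .stream_completion(prompt, vec![]) // -> CompletionRequestBuilder<M>
        .await?
        .temperature(0.7)
        .stream() // -> StreamingCompletionResponse<M::StreamingResponse>
        .await
}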

@@ -8,7 +8,8 @@ use super::decoders::sse::from_response as sse_from_response;
use crate::completion::{CompletionError, CompletionRequest};
use crate::json_utils::merge_inplace;
use crate::message::MessageError;
use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
use crate::streaming;
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult};
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
@@ -61,7 +62,7 @@ pub struct MessageDelta {
pub stop_sequence: Option<String>,
}
#[derive(Debug, Deserialize)]
#[derive(Debug, Deserialize, Clone)]
pub struct PartialUsage {
pub output_tokens: usize,
#[serde(default)]
@@ -75,11 +76,18 @@ struct ToolCallState {
input_json: String,
}
#[derive(Clone)]
pub struct StreamingCompletionResponse {
pub usage: PartialUsage,
}
impl StreamingCompletionModel for CompletionModel {
type StreamingResponse = StreamingCompletionResponse;
async fn stream(
&self,
completion_request: CompletionRequest,
) -> Result<StreamingResult, CompletionError> {
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
{
let max_tokens = if let Some(tokens) = completion_request.max_tokens {
tokens
} else if let Some(tokens) = self.default_max_tokens {
@@ -155,9 +163,10 @@ impl StreamingCompletionModel for CompletionModel {
// Use our SSE decoder to directly handle Server-Sent Events format
let sse_stream = sse_from_response(response);
Ok(Box::pin(stream! {
let stream: StreamingResult<Self::StreamingResponse> = Box::pin(stream! {
let mut current_tool_call: Option<ToolCallState> = None;
let mut sse_stream = Box::pin(sse_stream);
let mut input_tokens = 0;
while let Some(sse_result) = sse_stream.next().await {
match sse_result {
@@ -165,6 +174,24 @@ impl StreamingCompletionModel for CompletionModel {
// Parse the SSE data as a StreamingEvent
match serde_json::from_str::<StreamingEvent>(&sse.data) {
Ok(event) => {
match &event {
StreamingEvent::MessageStart { message } => {
input_tokens = message.usage.input_tokens;
},
StreamingEvent::MessageDelta { delta, usage } => {
if delta.stop_reason.is_some() {
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
usage: PartialUsage {
output_tokens: usage.output_tokens,
input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
}
}))
}
}
_ => {}
}
if let Some(result) = handle_event(&event, &mut current_tool_call) {
yield result;
}
@@ -184,19 +211,21 @@
}
}
}
}))
});
Ok(streaming::StreamingCompletionResponse::new(stream))
}
}
fn handle_event(
event: &StreamingEvent,
current_tool_call: &mut Option<ToolCallState>,
) -> Option<Result<StreamingChoice, CompletionError>> {
) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
match event {
StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
ContentDelta::TextDelta { text } => {
if current_tool_call.is_none() {
return Some(Ok(StreamingChoice::Message(text.clone())));
return Some(Ok(RawStreamingChoice::Message(text.clone())));
}
None
}
@@ -227,7 +256,7 @@ fn handle_event(
&tool_call.input_json
};
match serde_json::from_str(json_str) {
Ok(json_value) => Some(Ok(StreamingChoice::ToolCall(
Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall(
tool_call.name,
tool_call.id,
json_value,

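With input_tokens captured at MessageStart and the final response yielded when a stop reason arrives, Anthropic callers can read both counts after draining the stream. A sketch under assumptions (the model constant and max_tokens value are illustrative; the stream path above needs max_tokens from the request or the model default):

use rig::providers::anthropic;
use rig::streaming::{stream_to_stdout, StreamingPrompt};

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    let agent = anthropic::Client::from_env()
        .agent(anthropic::CLAUDE_3_5_SONNET)
        .max_tokens(1024)
        .build();

    let mut stream = agent.stream_prompt("Say hello in one sentence.").await?;
    stream_to_stdout(agent, &mut stream).await?;

    // PartialUsage now carries input_tokens (an Option) alongside output_tokens.
    if let Some(response) = stream.response {
        println!(
            "Usage: {:?} in / {} out tokens",
            response.usage.input_tokens, response.usage.output_tokens
        );
    }
    Ok(())
}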

@@ -12,7 +12,7 @@
use super::openai::{send_compatible_streaming_request, TranscriptionResponse};
use crate::json_utils::merge;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use crate::{
agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest},
@@ -570,10 +570,17 @@ impl completion::CompletionModel for CompletionModel {
// Azure OpenAI Streaming API
// -----------------------------------------------------
impl StreamingCompletionModel for CompletionModel {
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream(
&self,
request: CompletionRequest,
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(request)?;
request = merge(request, json!({"stream": true}));
request = merge(
request,
json!({"stream": true, "stream_options": {"include_usage": true}}),
);
let builder = self
.client


@@ -6,8 +6,9 @@ use crate::{
};
use super::client::Client;
use crate::completion::CompletionRequest;
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_json::{json, Value};
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
@@ -419,7 +420,7 @@ impl TryFrom<Message> for message::Message {
#[derive(Clone)]
pub struct CompletionModel {
client: Client,
pub(crate) client: Client,
pub model: String,
}
@@ -430,16 +431,11 @@ impl CompletionModel {
model: model.to_string(),
}
}
}
impl completion::CompletionModel for CompletionModel {
type Response = CompletionResponse;
#[cfg_attr(feature = "worker", worker::send)]
async fn completion(
pub(crate) fn create_completion_request(
&self,
completion_request: completion::CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
completion_request: CompletionRequest,
) -> Result<Value, CompletionError> {
let prompt = completion_request.prompt_with_context();
let mut messages: Vec<message::Message> =
@@ -468,23 +464,29 @@ impl completion::CompletionModel for CompletionModel {
"tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
});
if let Some(ref params) = completion_request.additional_params {
Ok(json_utils::merge(request.clone(), params.clone()))
} else {
Ok(request)
}
}
}
impl completion::CompletionModel for CompletionModel {
type Response = CompletionResponse;
#[cfg_attr(feature = "worker", worker::send)]
async fn completion(
&self,
completion_request: completion::CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
let request = self.create_completion_request(completion_request)?;
tracing::debug!(
"Cohere request: {}",
serde_json::to_string_pretty(&request)?
);
let response = self
.client
.post("/v2/chat")
.json(
&if let Some(ref params) = completion_request.additional_params {
json_utils::merge(request.clone(), params.clone())
} else {
request.clone()
},
)
.send()
.await?;
let response = self.client.post("/v2/chat").json(&request).send().await?;
if response.status().is_success() {
let text_response = response.text().await?;


@@ -12,6 +12,7 @@
pub mod client;
pub mod completion;
pub mod embeddings;
pub mod streaming;
pub use client::Client;
pub use client::{ApiErrorResponse, ApiResponse};
@@ -23,7 +24,7 @@ pub use embeddings::EmbeddingModel;
// ================================================================
/// `command-r-plus` completion model
pub const COMMAND_R_PLUS: &str = "comman-r-plus";
pub const COMMAND_R_PLUS: &str = "command-r-plus";
/// `command-r` completion model
pub const COMMAND_R: &str = "command-r";
/// `command` completion model


@@ -0,0 +1,188 @@
use crate::completion::{CompletionError, CompletionRequest};
use crate::providers::cohere::completion::Usage;
use crate::providers::cohere::CompletionModel;
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
use crate::{json_utils, streaming};
use async_stream::stream;
use futures::StreamExt;
use serde::Deserialize;
use serde_json::json;
#[derive(Debug, Deserialize)]
#[serde(rename_all = "kebab-case", tag = "type")]
enum StreamingEvent {
MessageStart,
ContentStart,
ContentDelta { delta: Option<Delta> },
ContentEnd,
ToolPlan,
ToolCallStart { delta: Option<Delta> },
ToolCallDelta { delta: Option<Delta> },
ToolCallEnd,
MessageEnd { delta: Option<MessageEndDelta> },
}
#[derive(Debug, Deserialize)]
struct MessageContentDelta {
text: Option<String>,
}
#[derive(Debug, Deserialize)]
struct MessageToolFunctionDelta {
name: Option<String>,
arguments: Option<String>,
}
#[derive(Debug, Deserialize)]
struct MessageToolCallDelta {
id: Option<String>,
function: Option<MessageToolFunctionDelta>,
}
#[derive(Debug, Deserialize)]
struct MessageDelta {
content: Option<MessageContentDelta>,
tool_calls: Option<MessageToolCallDelta>,
}
#[derive(Debug, Deserialize)]
struct Delta {
message: Option<MessageDelta>,
}
#[derive(Debug, Deserialize)]
struct MessageEndDelta {
usage: Option<Usage>,
}
#[derive(Clone)]
pub struct StreamingCompletionResponse {
pub usage: Option<Usage>,
}
impl StreamingCompletionModel for CompletionModel {
type StreamingResponse = StreamingCompletionResponse;
async fn stream(
&self,
request: CompletionRequest,
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
{
let request = self.create_completion_request(request)?;
let request = json_utils::merge(request, json!({"stream": true}));
tracing::debug!(
"Cohere request: {}",
serde_json::to_string_pretty(&request)?
);
let response = self.client.post("/v2/chat").json(&request).send().await?;
if !response.status().is_success() {
return Err(CompletionError::ProviderError(format!(
"{}: {}",
response.status(),
response.text().await?
)));
}
let stream = Box::pin(stream! {
let mut stream = response.bytes_stream();
let mut current_tool_call: Option<(String, String, String)> = None;
while let Some(chunk_result) = stream.next().await {
let chunk = match chunk_result {
Ok(c) => c,
Err(e) => {
yield Err(CompletionError::from(e));
break;
}
};
let text = match String::from_utf8(chunk.to_vec()) {
Ok(t) => t,
Err(e) => {
yield Err(CompletionError::ResponseError(e.to_string()));
break;
}
};
for line in text.lines() {
let Some(line) = line.strip_prefix("data: ") else {
continue;
};
let event = {
let result = serde_json::from_str::<StreamingEvent>(line);
let Ok(event) = result else {
continue;
};
event
};
match event {
StreamingEvent::ContentDelta { delta: Some(delta) } => {
let Some(message) = &delta.message else { continue; };
let Some(content) = &message.content else { continue; };
let Some(text) = &content.text else { continue; };
yield Ok(RawStreamingChoice::Message(text.clone()));
},
StreamingEvent::MessageEnd {delta: Some(delta)} => {
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
usage: delta.usage.clone()
}));
},
StreamingEvent::ToolCallStart { delta: Some(delta)} => {
// Skip the delta if there's any missing information,
// though this *should* all be present
let Some(message) = &delta.message else { continue; };
let Some(tool_calls) = &message.tool_calls else { continue; };
let Some(id) = tool_calls.id.clone() else { continue; };
let Some(function) = &tool_calls.function else { continue; };
let Some(name) = function.name.clone() else { continue; };
let Some(arguments) = function.arguments.clone() else { continue; };
current_tool_call = Some((id, name, arguments));
},
StreamingEvent::ToolCallDelta { delta: Some(delta)} => {
// Skip the delta if there's any missing information,
// though this *should* all be present
let Some(message) = &delta.message else { continue; };
let Some(tool_calls) = &message.tool_calls else { continue; };
let Some(function) = &tool_calls.function else { continue; };
let Some(arguments) = function.arguments.clone() else { continue; };
if let Some(tc) = current_tool_call.clone() {
current_tool_call = Some((
tc.0,
tc.1,
format!("{}{}", tc.2, arguments)
));
};
},
StreamingEvent::ToolCallEnd => {
let Some(tc) = current_tool_call.clone() else { continue; };
let Ok(args) = serde_json::from_str(&tc.2) else { continue; };
yield Ok(RawStreamingChoice::ToolCall(
tc.0,
tc.1,
args
));
current_tool_call = None;
},
_ => {}
};
}
}
});
Ok(streaming::StreamingCompletionResponse::new(stream))
}
}

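The event enum above leans on serde's internal tagging with kebab-case variant names, so Cohere v2 lines such as content-delta map directly onto it. A hypothetical crate-internal test sketch (not part of the commit) of that mapping:

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn content_delta_parses() {
        // Cohere v2 SSE events carry a kebab-case "type" tag.
        let line = r#"{"type":"content-delta","delta":{"message":{"content":{"text":"Hi"}}}}"#;
        let event: StreamingEvent =
            serde_json::from_str(line).expect("content-delta event should parse");
        assert!(matches!(event, StreamingEvent::ContentDelta { delta: Some(_) }));
    }
}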

@@ -10,8 +10,9 @@
//! ```
use crate::json_utils::merge;
use crate::providers::openai;
use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use crate::{
completion::{self, CompletionError, CompletionModel, CompletionRequest},
extractor::ExtractorBuilder,
@@ -463,13 +464,17 @@ impl CompletionModel for DeepSeekCompletionModel {
}
impl StreamingCompletionModel for DeepSeekCompletionModel {
type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream(
&self,
completion_request: CompletionRequest,
) -> Result<StreamingResult, CompletionError> {
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(completion_request)?;
request = merge(request, json!({"stream": true}));
request = merge(
request,
json!({"stream": true, "stream_options": {"include_usage": true}}),
);
let builder = self.client.post("/v1/chat/completions").json(&request);
send_compatible_streaming_request(builder).await


@@ -13,7 +13,7 @@
use super::openai;
use crate::json_utils::merge;
use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use crate::{
agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest},
@@ -495,10 +495,18 @@ impl completion::CompletionModel for CompletionModel {
}
impl StreamingCompletionModel for CompletionModel {
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream(
&self,
request: CompletionRequest,
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(request)?;
request = merge(request, json!({"stream": true}));
request = merge(
request,
json!({"stream": true, "stream_options": {"include_usage": true}}),
);
let builder = self.client.post("/chat/completions").json(&request);


@@ -609,7 +609,7 @@ pub mod gemini_api_types {
HarmCategoryCivicIntegrity,
}
#[derive(Debug, Deserialize)]
#[derive(Debug, Deserialize, Clone, Default)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
pub prompt_token_count: i32,


@@ -2,12 +2,17 @@ use async_stream::stream;
use futures::StreamExt;
use serde::Deserialize;
use super::completion::{create_request_body, gemini_api_types::ContentCandidate, CompletionModel};
use crate::{
completion::{CompletionError, CompletionRequest},
streaming::{self, StreamingCompletionModel, StreamingResult},
streaming::{self, StreamingCompletionModel},
};
use super::completion::{create_request_body, gemini_api_types::ContentCandidate, CompletionModel};
#[derive(Debug, Deserialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PartialUsage {
pub total_token_count: i32,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
@@ -15,13 +20,21 @@ pub struct StreamGenerateContentResponse {
/// Candidate responses from the model.
pub candidates: Vec<ContentCandidate>,
pub model_version: Option<String>,
pub usage_metadata: Option<PartialUsage>,
}
#[derive(Clone)]
pub struct StreamingCompletionResponse {
pub usage_metadata: PartialUsage,
}
impl StreamingCompletionModel for CompletionModel {
type StreamingResponse = StreamingCompletionResponse;
async fn stream(
&self,
completion_request: CompletionRequest,
) -> Result<StreamingResult, CompletionError> {
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
{
let request = create_request_body(completion_request)?;
let response = self
@@ -42,7 +55,7 @@ impl StreamingCompletionModel for CompletionModel {
)));
}
Ok(Box::pin(stream! {
let stream = Box::pin(stream! {
let mut stream = response.bytes_stream();
while let Some(chunk_result) = stream.next().await {
@@ -74,13 +87,23 @@ impl StreamingCompletionModel for CompletionModel {
match choice.content.parts.first() {
super::completion::gemini_api_types::Part::Text(text)
=> yield Ok(streaming::StreamingChoice::Message(text)),
=> yield Ok(streaming::RawStreamingChoice::Message(text)),
super::completion::gemini_api_types::Part::FunctionCall(function_call)
=> yield Ok(streaming::StreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
=> yield Ok(streaming::RawStreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
_ => panic!("Unsupported response type with streaming.")
};
}
if choice.finish_reason.is_some() {
yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
usage_metadata: PartialUsage {
total_token_count: data.usage_metadata.unwrap().total_token_count,
}
}))
}
}
}
});
Ok(streaming::StreamingCompletionResponse::new(stream))
}
}


@@ -10,7 +10,8 @@
//! ```
use super::openai::{send_compatible_streaming_request, CompletionResponse, TranscriptionResponse};
use crate::json_utils::merge;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::providers::openai;
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use crate::{
agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest},
@@ -363,10 +364,17 @@ impl completion::CompletionModel for CompletionModel {
}
impl StreamingCompletionModel for CompletionModel {
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream(
&self,
request: CompletionRequest,
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(request)?;
request = merge(request, json!({"stream": true}));
request = merge(
request,
json!({"stream": true, "stream_options": {"include_usage": true}}),
);
let builder = self.client.post("/chat/completions").json(&request);


@@ -1,9 +1,9 @@
use super::completion::CompletionModel;
use crate::completion::{CompletionError, CompletionRequest};
use crate::json_utils;
use crate::json_utils::merge_inplace;
use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::providers::openai::{send_compatible_streaming_request, StreamingCompletionResponse};
use crate::streaming::StreamingCompletionModel;
use crate::{json_utils, streaming};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::convert::Infallible;
@@ -55,14 +55,19 @@ struct CompletionChunk {
}
impl StreamingCompletionModel for CompletionModel {
type StreamingResponse = StreamingCompletionResponse;
async fn stream(
&self,
completion_request: CompletionRequest,
) -> Result<StreamingResult, CompletionError> {
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
{
let mut request = self.create_request_body(&completion_request)?;
// Enable streaming
merge_inplace(&mut request, json!({"stream": true}));
merge_inplace(
&mut request,
json!({"stream": true, "stream_options": {"include_usage": true}}),
);
if let Some(ref params) = completion_request.additional_params {
merge_inplace(&mut request, params.clone());


@@ -12,7 +12,7 @@
use super::openai::{send_compatible_streaming_request, AssistantContent};
use crate::json_utils::merge_inplace;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use crate::{
agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest},
@@ -390,13 +390,17 @@ impl completion::CompletionModel for CompletionModel {
}
impl StreamingCompletionModel for CompletionModel {
type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream(
&self,
completion_request: CompletionRequest,
) -> Result<StreamingResult, CompletionError> {
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(completion_request)?;
merge_inplace(&mut request, json!({"stream": true}));
merge_inplace(
&mut request,
json!({"stream": true, "stream_options": {"include_usage": true}}),
);
let builder = self.client.post("/chat/completions").json(&request);
@@ -526,8 +530,10 @@ mod image_generation {
// ======================================
// Hyperbolic Audio Generation API
// ======================================
use crate::providers::openai;
#[cfg(feature = "audio")]
pub use audio_generation::*;
#[cfg(feature = "audio")]
mod audio_generation {
use super::{ApiResponse, Client};


@@ -8,8 +8,9 @@
//!
//! ```
use crate::json_utils::merge;
use crate::providers::openai;
use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use crate::{
agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest},
@@ -347,10 +348,11 @@ impl completion::CompletionModel for CompletionModel {
}
impl StreamingCompletionModel for CompletionModel {
type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream(
&self,
completion_request: CompletionRequest,
) -> Result<StreamingResult, CompletionError> {
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(completion_request)?;
request = merge(request, json!({"stream": true}));


@@ -11,7 +11,7 @@
use crate::json_utils::merge;
use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use crate::{
agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest},
@@ -228,10 +228,18 @@ impl completion::CompletionModel for CompletionModel {
}
impl StreamingCompletionModel for CompletionModel {
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream(
&self,
request: CompletionRequest,
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(request)?;
request = merge(request, json!({"stream": true}));
request = merge(
request,
json!({"stream": true, "stream_options": {"include_usage": true}}),
);
let builder = self.client.post("/chat/completions").json(&request);


@@ -39,7 +39,7 @@
//! let extractor = client.extractor::<serde_json::Value>("llama3.2");
//! ```
use crate::json_utils::merge_inplace;
use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
use crate::{
agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest},
@@ -47,7 +47,7 @@ use crate::{
extractor::ExtractorBuilder,
json_utils, message,
message::{ImageDetail, Text},
Embed, OneOrMany,
streaming, Embed, OneOrMany,
};
use async_stream::stream;
use futures::StreamExt;
@@ -405,8 +405,25 @@ impl completion::CompletionModel for CompletionModel {
}
}
#[derive(Clone)]
pub struct StreamingCompletionResponse {
pub done_reason: Option<String>,
pub total_duration: Option<u64>,
pub load_duration: Option<u64>,
pub prompt_eval_count: Option<u64>,
pub prompt_eval_duration: Option<u64>,
pub eval_count: Option<u64>,
pub eval_duration: Option<u64>,
}
impl StreamingCompletionModel for CompletionModel {
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
type StreamingResponse = StreamingCompletionResponse;
async fn stream(
&self,
request: CompletionRequest,
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
{
let mut request_payload = self.create_completion_request(request)?;
merge_inplace(&mut request_payload, json!({"stream": true}));
@@ -426,7 +443,7 @@ impl StreamingCompletionModel for CompletionModel {
return Err(CompletionError::ProviderError(err_text));
}
Ok(Box::pin(stream! {
let stream = Box::pin(stream! {
let mut stream = response.bytes_stream();
while let Some(chunk_result) = stream.next().await {
let chunk = match chunk_result {
@@ -456,22 +473,36 @@ impl StreamingCompletionModel for CompletionModel {
match response.message {
Message::Assistant{ content, tool_calls, .. } => {
if !content.is_empty() {
yield Ok(StreamingChoice::Message(content))
yield Ok(RawStreamingChoice::Message(content))
}
for tool_call in tool_calls.iter() {
let function = tool_call.function.clone();
yield Ok(StreamingChoice::ToolCall(function.name, "".to_string(), function.arguments));
yield Ok(RawStreamingChoice::ToolCall(function.name, "".to_string(), function.arguments));
}
}
_ => {
continue;
}
}
if response.done {
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
total_duration: response.total_duration,
load_duration: response.load_duration,
prompt_eval_count: response.prompt_eval_count,
prompt_eval_duration: response.prompt_eval_duration,
eval_count: response.eval_count,
eval_duration: response.eval_duration,
done_reason: response.done_reason,
}));
}
}
}))
}
});
Ok(streaming::StreamingCompletionResponse::new(stream))
}
}


@@ -2,14 +2,16 @@ use super::completion::CompletionModel;
use crate::completion::{CompletionError, CompletionRequest};
use crate::json_utils;
use crate::json_utils::merge;
use crate::providers::openai::Usage;
use crate::streaming;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
use async_stream::stream;
use futures::StreamExt;
use reqwest::RequestBuilder;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
use tracing::debug;
// ================================================================
// OpenAI Completion Streaming API
@@ -25,10 +27,11 @@ pub struct StreamingFunction {
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingToolCall {
pub index: usize,
pub id: Option<String>,
pub function: StreamingFunction,
}
#[derive(Deserialize)]
#[derive(Deserialize, Debug)]
struct StreamingDelta {
#[serde(default)]
content: Option<String>,
@@ -36,23 +39,34 @@ struct StreamingDelta {
tool_calls: Vec<StreamingToolCall>,
}
#[derive(Deserialize)]
#[derive(Deserialize, Debug)]
struct StreamingChoice {
delta: StreamingDelta,
}
#[derive(Deserialize)]
struct StreamingCompletionResponse {
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
choices: Vec<StreamingChoice>,
usage: Option<Usage>,
}
#[derive(Clone)]
pub struct StreamingCompletionResponse {
pub usage: Usage,
}
impl StreamingCompletionModel for CompletionModel {
type StreamingResponse = StreamingCompletionResponse;
async fn stream(
&self,
completion_request: CompletionRequest,
) -> Result<StreamingResult, CompletionError> {
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
{
let mut request = self.create_completion_request(completion_request)?;
request = merge(request, json!({"stream": true}));
request = merge(
request,
json!({"stream": true, "stream_options": {"include_usage": true}}),
);
let builder = self.client.post("/chat/completions").json(&request);
send_compatible_streaming_request(builder).await
@@ -61,7 +75,7 @@ impl StreamingCompletionModel for CompletionModel {
pub async fn send_compatible_streaming_request(
request_builder: RequestBuilder,
) -> Result<StreamingResult, CompletionError> {
) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError> {
let response = request_builder.send().await?;
if !response.status().is_success() {
@@ -73,11 +87,16 @@ pub async fn send_compatible_streaming_request(
}
// Handle OpenAI Compatible SSE chunks
Ok(Box::pin(stream! {
let inner = Box::pin(stream! {
let mut stream = response.bytes_stream();
let mut final_usage = Usage {
prompt_tokens: 0,
total_tokens: 0
};
let mut partial_data = None;
let mut calls: HashMap<usize, (String, String)> = HashMap::new();
let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
while let Some(chunk_result) = stream.next().await {
let chunk = match chunk_result {
@@ -100,8 +119,6 @@
for line in text.lines() {
let mut line = line.to_string();
// If there was a remaining part, concat with current line
if partial_data.is_some() {
line = format!("{}{}", partial_data.unwrap(), line);
@@ -121,64 +138,85 @@
}
}
let data = serde_json::from_str::<StreamingCompletionResponse>(&line);
let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
let Ok(data) = data else {
let err = data.unwrap_err();
debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
continue;
};
let choice = data.choices.first().expect("Should have at least one choice");
if let Some(choice) = data.choices.first() {
let delta = &choice.delta;
if !delta.tool_calls.is_empty() {
for tool_call in &delta.tool_calls {
let function = tool_call.function.clone();
// Start of tool call
// name: Some(String)
// arguments: None
if function.name.is_some() && function.arguments.is_empty() {
calls.insert(tool_call.index, (function.name.clone().unwrap(), "".to_string()));
let id = tool_call.id.clone().unwrap_or("".to_string());
calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
}
// Part of tool call
// name: None
// arguments: Some(String)
else if function.name.is_none() && !function.arguments.is_empty() {
let Some((name, arguments)) = calls.get(&tool_call.index) else {
let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
debug!("Partial tool call received but tool call was never started.");
continue;
};
let new_arguments = &tool_call.function.arguments;
let arguments = format!("{}{}", arguments, new_arguments);
calls.insert(tool_call.index, (name.clone(), arguments));
calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
}
// Entire tool call
else {
let name = function.name.unwrap();
let id = tool_call.id.clone().unwrap_or("".to_string());
let name = function.name.expect("function name should be present for complete tool call");
let arguments = function.arguments;
let Ok(arguments) = serde_json::from_str(&arguments) else {
debug!("Couldn't serialize '{}' as a json value", arguments);
continue;
};
yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
yield Ok(streaming::RawStreamingChoice::ToolCall(id, name, arguments))
}
}
}
if let Some(content) = &choice.delta.content {
yield Ok(streaming::StreamingChoice::Message(content.clone()))
yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
}
}
if let Some(usage) = data.usage {
final_usage = usage.clone();
}
}
}
for (_, (name, arguments)) in calls {
for (_, (id, name, arguments)) in calls {
let Ok(arguments) = serde_json::from_str(&arguments) else {
continue;
};
yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
println!("{id} {name}");
yield Ok(RawStreamingChoice::ToolCall(id, name, arguments))
}
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
usage: final_usage.clone()
}))
});
Ok(streaming::StreamingCompletionResponse::new(inner))
}

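The switch from expecting a first choice to `if let Some(choice)` matters because of the new stream_options flag: with include_usage set, OpenAI-compatible backends send a final chunk whose choices array is empty and which carries only the usage object, which the loop above stashes into final_usage. A hypothetical crate-internal test sketch (not part of the commit) of that final chunk:

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn usage_only_chunk_parses() {
        // Final chunk under include_usage: no choices, just usage.
        let line = r#"{"choices":[],"usage":{"prompt_tokens":5,"total_tokens":9}}"#;
        let chunk: StreamingCompletionChunk =
            serde_json::from_str(line).expect("usage-only chunk should parse");
        assert!(chunk.choices.is_empty());
        assert_eq!(chunk.usage.expect("usage should be set").total_tokens, 9);
    }
}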

@@ -0,0 +1,125 @@
use crate::{agent::AgentBuilder, extractor::ExtractorBuilder};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use super::completion::CompletionModel;
// ================================================================
// Main openrouter Client
// ================================================================
const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
#[derive(Clone)]
pub struct Client {
base_url: String,
http_client: reqwest::Client,
}
impl Client {
/// Create a new OpenRouter client with the given API key.
pub fn new(api_key: &str) -> Self {
Self::from_url(api_key, OPENROUTER_API_BASE_URL)
}
/// Create a new OpenRouter client with the given API key and base API URL.
pub fn from_url(api_key: &str, base_url: &str) -> Self {
Self {
base_url: base_url.to_string(),
http_client: reqwest::Client::builder()
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"Authorization",
format!("Bearer {}", api_key)
.parse()
.expect("Bearer token should parse"),
);
headers
})
.build()
.expect("OpenRouter reqwest client should build"),
}
}
/// Create a new OpenRouter client from the `OPENROUTER_API_KEY` environment variable.
/// Panics if the environment variable is not set.
pub fn from_env() -> Self {
let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
Self::new(&api_key)
}
pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
self.http_client.post(url)
}
/// Create a completion model with the given name.
///
/// # Example
/// ```
/// use rig::providers::openrouter::{Client, self};
///
/// // Initialize the openrouter client
/// let openrouter = Client::new("your-openrouter-api-key");
///
/// let llama_3_1_8b = openrouter.completion_model(openrouter::LLAMA_3_1_8B);
/// ```
pub fn completion_model(&self, model: &str) -> CompletionModel {
CompletionModel::new(self.clone(), model)
}
/// Create an agent builder with the given completion model.
///
/// # Example
/// ```
/// use rig::providers::openrouter::{Client, self};
///
/// // Initialize the OpenRouter client
/// let openrouter = Client::new("your-openrouter-api-key");
///
/// let agent = openrouter.agent(openrouter::LLAMA_3_1_8B)
/// .preamble("You are comedian AI with a mission to make people laugh.")
/// .temperature(0.0)
/// .build();
/// ```
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
AgentBuilder::new(self.completion_model(model))
}
/// Create an extractor builder with the given completion model.
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
&self,
model: &str,
) -> ExtractorBuilder<T, CompletionModel> {
ExtractorBuilder::new(self.completion_model(model))
}
}
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
pub(crate) message: String,
}
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
Ok(T),
Err(ApiErrorResponse),
}
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
pub prompt_tokens: usize,
pub completion_tokens: usize,
pub total_tokens: usize,
}
impl std::fmt::Display for Usage {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Prompt tokens: {} Total tokens: {}",
self.prompt_tokens, self.total_tokens
)
}
}

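A minimal sketch of the relocated client in use, following the module's own doc examples (the model slug is illustrative; the streaming side is wired up in the new streaming module further below, whose FinalCompletionResponse surfaces OpenRouter's usage block):

use rig::providers::openrouter;
use rig::streaming::{stream_to_stdout, StreamingPrompt};

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    // Client::from_env reads OPENROUTER_API_KEY, as documented above.
    let agent = openrouter::Client::from_env()
        .agent("meta-llama/llama-3.1-8b-instruct") // any OpenRouter model slug
        .preamble("Be precise and concise.")
        .build();

    let mut stream = agent.stream_prompt("Hello!").await?;
    stream_to_stdout(agent, &mut stream).await?;

    if let Some(response) = stream.response {
        println!("Usage: {:?}", response.usage);
    }
    Ok(())
}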

@@ -1,147 +1,16 @@
//! OpenRouter Inference API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::openrouter;
//!
//! let client = openrouter::Client::new("YOUR_API_KEY");
//!
//! let llama_3_1_8b = client.completion_model(openrouter::LLAMA_3_1_8B);
//! ```
use serde::Deserialize;
use super::client::{ApiErrorResponse, ApiResponse, Client, Usage};
use crate::{
agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest},
extractor::ExtractorBuilder,
json_utils,
providers::openai::Message,
OneOrMany,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_json::{json, Value};
use super::openai::AssistantContent;
// ================================================================
// Main openrouter Client
// ================================================================
const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
#[derive(Clone)]
pub struct Client {
base_url: String,
http_client: reqwest::Client,
}
impl Client {
/// Create a new OpenRouter client with the given API key.
pub fn new(api_key: &str) -> Self {
Self::from_url(api_key, OPENROUTER_API_BASE_URL)
}
/// Create a new OpenRouter client with the given API key and base API URL.
pub fn from_url(api_key: &str, base_url: &str) -> Self {
Self {
base_url: base_url.to_string(),
http_client: reqwest::Client::builder()
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"Authorization",
format!("Bearer {}", api_key)
.parse()
.expect("Bearer token should parse"),
);
headers
})
.build()
.expect("OpenRouter reqwest client should build"),
}
}
/// Create a new openrouter client from the `openrouter_API_KEY` environment variable.
/// Panics if the environment variable is not set.
pub fn from_env() -> Self {
let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
Self::new(&api_key)
}
fn post(&self, path: &str) -> reqwest::RequestBuilder {
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
self.http_client.post(url)
}
/// Create a completion model with the given name.
///
/// # Example
/// ```
/// use rig::providers::openrouter::{Client, self};
///
/// // Initialize the openrouter client
/// let openrouter = Client::new("your-openrouter-api-key");
///
/// let llama_3_1_8b = openrouter.completion_model(openrouter::LLAMA_3_1_8B);
/// ```
pub fn completion_model(&self, model: &str) -> CompletionModel {
CompletionModel::new(self.clone(), model)
}
/// Create an agent builder with the given completion model.
///
/// # Example
/// ```
/// use rig::providers::openrouter::{Client, self};
///
/// // Initialize the Eternal client
/// let openrouter = Client::new("your-openrouter-api-key");
///
/// let agent = openrouter.agent(openrouter::LLAMA_3_1_8B)
/// .preamble("You are comedian AI with a mission to make people laugh.")
/// .temperature(0.0)
/// .build();
/// ```
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
AgentBuilder::new(self.completion_model(model))
}
/// Create an extractor builder with the given completion model.
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
&self,
model: &str,
) -> ExtractorBuilder<T, CompletionModel> {
ExtractorBuilder::new(self.completion_model(model))
}
}
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
message: String,
}
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
Ok(T),
Err(ApiErrorResponse),
}
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
pub prompt_tokens: usize,
pub completion_tokens: usize,
pub total_tokens: usize,
}
impl std::fmt::Display for Usage {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Prompt tokens: {} Total tokens: {}",
self.prompt_tokens, self.total_tokens
)
}
}
use crate::providers::openai::AssistantContent;
// ================================================================
// OpenRouter Completion API
@@ -241,7 +110,7 @@ pub struct Choice {
#[derive(Clone)]
pub struct CompletionModel {
client: Client,
pub(crate) client: Client,
/// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
pub model: String,
}
@@ -253,16 +122,11 @@ impl CompletionModel {
model: model.to_string(),
}
}
}
impl completion::CompletionModel for CompletionModel {
type Response = CompletionResponse;
#[cfg_attr(feature = "worker", worker::send)]
async fn completion(
pub(crate) fn create_completion_request(
&self,
completion_request: CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
) -> Result<Value, CompletionError> {
// Add preamble to chat history (if available)
let mut full_history: Vec<Message> = match &completion_request.preamble {
Some(preamble) => vec![Message::system(preamble)],
@@ -292,16 +156,30 @@ impl completion::CompletionModel for CompletionModel {
"temperature": completion_request.temperature,
});
let response = self
.client
.post("/chat/completions")
.json(
&if let Some(params) = completion_request.additional_params {
let request = if let Some(params) = completion_request.additional_params {
json_utils::merge(request, params)
} else {
request
},
)
};
Ok(request)
}
}
impl completion::CompletionModel for CompletionModel {
type Response = CompletionResponse;
#[cfg_attr(feature = "worker", worker::send)]
async fn completion(
&self,
completion_request: CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
let request = self.create_completion_request(completion_request)?;
let response = self
.client
.post("/chat/completions")
.json(&request)
.send()
.await?;


@@ -0,0 +1,17 @@
//! OpenRouter Inference API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::openrouter;
//!
//! let client = openrouter::Client::new("YOUR_API_KEY");
//!
//! let llama_3_1_8b = client.completion_model(openrouter::LLAMA_3_1_8B);
//! ```
pub mod client;
pub mod completion;
pub mod streaming;
pub use client::*;
pub use completion::*;


@@ -0,0 +1,313 @@
use std::collections::HashMap;
use crate::{
json_utils,
message::{ToolCall, ToolFunction},
streaming::{self},
};
use async_stream::stream;
use futures::StreamExt;
use reqwest::RequestBuilder;
use serde_json::{json, Value};
use crate::{
completion::{CompletionError, CompletionRequest},
streaming::StreamingCompletionModel,
};
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
pub id: String,
pub choices: Vec<StreamingChoice>,
pub created: u64,
pub model: String,
pub object: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<ResponseUsage>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamingChoice {
#[serde(skip_serializing_if = "Option::is_none")]
pub finish_reason: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub native_finish_reason: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<Value>,
pub index: usize,
#[serde(skip_serializing_if = "Option::is_none")]
pub message: Option<MessageResponse>,
#[serde(skip_serializing_if = "Option::is_none")]
pub delta: Option<DeltaResponse>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<ErrorResponse>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct MessageResponse {
pub role: String,
pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub refusal: Option<Value>,
#[serde(default)]
pub tool_calls: Vec<OpenRouterToolCall>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct OpenRouterToolFunction {
pub name: Option<String>,
pub arguments: Option<String>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct OpenRouterToolCall {
pub index: usize,
pub id: Option<String>,
pub r#type: Option<String>,
pub function: OpenRouterToolFunction,
}
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct ResponseUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct ErrorResponse {
pub code: i32,
pub message: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<HashMap<String, Value>>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct DeltaResponse {
pub role: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(default)]
pub tool_calls: Vec<OpenRouterToolCall>,
#[serde(skip_serializing_if = "Option::is_none")]
pub native_finish_reason: Option<String>,
}
#[derive(Clone)]
pub struct FinalCompletionResponse {
pub usage: ResponseUsage,
}
impl StreamingCompletionModel for super::CompletionModel {
type StreamingResponse = FinalCompletionResponse;
async fn stream(
&self,
completion_request: CompletionRequest,
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
{
let request = self.create_completion_request(completion_request)?;
let request = json_utils::merge(request, json!({"stream": true}));
let builder = self.client.post("/chat/completions").json(&request);
send_streaming_request(builder).await
}
}
pub async fn send_streaming_request(
request_builder: RequestBuilder,
) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError> {
let response = request_builder.send().await?;
if !response.status().is_success() {
return Err(CompletionError::ProviderError(format!(
"{}: {}",
response.status(),
response.text().await?
)));
}
// Handle OpenAI Compatible SSE chunks
let stream = Box::pin(stream! {
let mut stream = response.bytes_stream();
let mut tool_calls = HashMap::new();
let mut partial_line = String::new();
let mut final_usage = None;
while let Some(chunk_result) = stream.next().await {
let chunk = match chunk_result {
Ok(c) => c,
Err(e) => {
yield Err(CompletionError::from(e));
break;
}
};
let text = match String::from_utf8(chunk.to_vec()) {
Ok(t) => t,
Err(e) => {
yield Err(CompletionError::ResponseError(e.to_string()));
break;
}
};
for line in text.lines() {
let mut line = line.to_string();
// Skip empty lines and processing messages, as well as [DONE] (might be useful though)
if line.trim().is_empty() || line.trim() == ": OPENROUTER PROCESSING" || line.trim() == "data: [DONE]" {
continue;
}
// Handle data: prefix
line = line.strip_prefix("data: ").unwrap_or(&line).to_string();
// If line starts with { but doesn't end with }, it's a partial JSON
if line.starts_with('{') && !line.ends_with('}') {
partial_line = line;
continue;
}
// If we have a partial line and this line ends with }, complete it
if !partial_line.is_empty() {
if line.ends_with('}') {
partial_line.push_str(&line);
line = partial_line;
partial_line = String::new();
} else {
partial_line.push_str(&line);
continue;
}
}
let data = match serde_json::from_str::<StreamingCompletionResponse>(&line) {
Ok(data) => data,
Err(_) => {
// Not a parseable chunk (a keep-alive comment or an incomplete fragment); skip it
continue;
}
};
let Some(choice) = data.choices.first() else {
// A chunk without choices carries nothing to forward; skip it
continue;
};
// Tool calls stream as fragmented deltas that must be reassembled, e.g.:
// [{"index": 0, "id": "call_DdmO9pD3xa9XTPNJ32zg2hcA", "function": {"arguments": "", "name": "get_weather"}, "type": "function"}]
// [{"index": 0, "id": null, "function": {"arguments": "{\"", "name": null}, "type": null}]
// [{"index": 0, "id": null, "function": {"arguments": "location", "name": null}, "type": null}]
// [{"index": 0, "id": null, "function": {"arguments": "\":\"", "name": null}, "type": null}]
// [{"index": 0, "id": null, "function": {"arguments": "Paris", "name": null}, "type": null}]
// [{"index": 0, "id": null, "function": {"arguments": ",", "name": null}, "type": null}]
// [{"index": 0, "id": null, "function": {"arguments": " France", "name": null}, "type": null}]
// [{"index": 0, "id": null, "function": {"arguments": "\"}", "name": null}, "type": null}]
if let Some(delta) = &choice.delta {
if !delta.tool_calls.is_empty() {
for tool_call in &delta.tool_calls {
let index = tool_call.index;
// Get or create tool call entry
let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
id: String::new(),
function: ToolFunction {
name: String::new(),
arguments: serde_json::Value::Null,
},
});
// Update fields if present
if let Some(id) = &tool_call.id {
if !id.is_empty() {
existing_tool_call.id = id.clone();
}
}
if let Some(name) = &tool_call.function.name {
if !name.is_empty() {
existing_tool_call.function.name = name.clone();
}
}
if let Some(chunk) = &tool_call.function.arguments {
// Convert current arguments to string if needed
let current_args = match &existing_tool_call.function.arguments {
serde_json::Value::Null => String::new(),
serde_json::Value::String(s) => s.clone(),
v => v.to_string(),
};
// Concatenate the new chunk
let combined = format!("{}{}", current_args, chunk);
// Try to parse as JSON if it looks complete
if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
match serde_json::from_str(&combined) {
Ok(parsed) => existing_tool_call.function.arguments = parsed,
Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
}
} else {
existing_tool_call.function.arguments = serde_json::Value::String(combined);
}
}
}
}
if let Some(content) = &delta.content {
if !content.is_empty() {
yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
}
}
}
// Usage may arrive on any chunk (often the final one), with or
// without a delta, so record it outside the delta handling
if let Some(usage) = data.usage {
final_usage = Some(usage);
}
// Handle message format
if let Some(message) = &choice.message {
if !message.tool_calls.is_empty() {
for tool_call in &message.tool_calls {
let name = tool_call.function.name.clone();
let id = tool_call.id.clone();
let arguments = if let Some(args) = &tool_call.function.arguments {
// Try to parse the string as JSON, fallback to string value
match serde_json::from_str(args) {
Ok(v) => v,
Err(_) => serde_json::Value::String(args.to_string()),
}
} else {
serde_json::Value::Null
};
let index = tool_call.index;
tool_calls.insert(index, ToolCall{
id: id.unwrap_or_default(),
function: ToolFunction {
name: name.unwrap_or_default(),
arguments,
},
});
}
}
if !message.content.is_empty() {
yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()))
}
}
}
}
for tool_call in tool_calls.into_values() {
// Yield as (id, name, arguments) to match the order the
// `StreamingCompletionResponse` aggregation expects
yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call.id, tool_call.function.name, tool_call.function.arguments));
}
yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
usage: final_usage.unwrap_or_default()
}))
});
Ok(streaming::StreamingCompletionResponse::new(stream))
}
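// A minimal sketch (illustrative, not part of the provider above) of the
// argument-accumulation strategy used in the streaming loop: fragments are
// appended to a buffer and only parsed once the buffer looks like a
// complete JSON object.
#[cfg(test)]
mod argument_accumulation_tests {
fn accumulate_fragment(buffer: &mut String, fragment: &str) -> Option<serde_json::Value> {
buffer.push_str(fragment);
let trimmed = buffer.trim();
if trimmed.starts_with('{') && trimmed.ends_with('}') {
serde_json::from_str(trimmed).ok()
} else {
None
}
}
#[test]
fn accumulates_fragments_into_json() {
let mut buffer = String::new();
let fragments = ["{\"", "location", "\":\"", "Paris", ", France", "\"}"];
let mut parsed = None;
for fragment in fragments {
parsed = accumulate_fragment(&mut buffer, fragment);
}
assert_eq!(parsed, Some(serde_json::json!({"location": "Paris, France"})));
}
}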

View File

@ -18,8 +18,9 @@ use crate::{
use crate::completion::CompletionRequest;
use crate::json_utils::merge;
use crate::providers::openai;
use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
@ -345,10 +346,11 @@ impl completion::CompletionModel for CompletionModel {
}
impl StreamingCompletionModel for CompletionModel {
type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream(
&self,
completion_request: completion::CompletionRequest,
) -> Result<StreamingResult, CompletionError> {
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(completion_request)?;
request = merge(request, json!({"stream": true}));

View File

@ -1,18 +1,21 @@
use serde_json::json;
use super::completion::CompletionModel;
use crate::providers::openai;
use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::StreamingCompletionResponse;
use crate::{
completion::{CompletionError, CompletionRequest},
json_utils::merge,
streaming::{StreamingCompletionModel, StreamingResult},
streaming::StreamingCompletionModel,
};
impl StreamingCompletionModel for CompletionModel {
type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream(
&self,
completion_request: CompletionRequest,
) -> Result<StreamingResult, CompletionError> {
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(completion_request)?;
request = merge(request, json!({"stream_tokens": true}));

View File

@ -1,15 +1,17 @@
use crate::completion::{CompletionError, CompletionRequest};
use crate::json_utils::merge;
use crate::providers::openai;
use crate::providers::openai::send_compatible_streaming_request;
use crate::providers::xai::completion::CompletionModel;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use serde_json::json;
impl StreamingCompletionModel for CompletionModel {
type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream(
&self,
completion_request: CompletionRequest,
) -> Result<StreamingResult, CompletionError> {
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(completion_request)?;
request = merge(request, json!({"stream": true}));

View File

@ -11,59 +11,150 @@
use crate::agent::Agent;
use crate::completion::{
CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, Message,
CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder,
CompletionResponse, Message,
};
use crate::message::{AssistantContent, ToolCall, ToolFunction};
use crate::OneOrMany;
use futures::{Stream, StreamExt};
use std::boxed::Box;
use std::fmt::{Display, Formatter};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
/// Enum representing a streaming chunk from the model
#[derive(Debug)]
pub enum StreamingChoice {
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R: Clone> {
/// A text chunk from a message response
Message(String),
/// A tool call response chunk, carried as (id, name, arguments)
ToolCall(String, String, serde_json::Value),
}
impl Display for StreamingChoice {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
StreamingChoice::Message(text) => write!(f, "{}", text),
StreamingChoice::ToolCall(name, id, params) => {
write!(f, "Tool call: {} {} {:?}", name, id, params)
}
}
}
/// The final response object; it must be yielded if you want the
/// `response` field to be populated on the `StreamingCompletionResponse`
FinalResponse(R),
}
#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult =
Pin<Box<dyn Stream<Item = Result<StreamingChoice, CompletionError>> + Send>>;
pub type StreamingResult<R> =
Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;
#[cfg(target_arch = "wasm32")]
pub type StreamingResult = Pin<Box<dyn Stream<Item = Result<StreamingChoice, CompletionError>>>>;
pub type StreamingResult<R> =
Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
/// The response from a streaming completion request;
/// `choice` and `response` are populated once the
/// `inner` stream has been fully consumed.
pub struct StreamingCompletionResponse<R: Clone + Unpin> {
inner: StreamingResult<R>,
text: String,
tool_calls: Vec<ToolCall>,
/// The final aggregated message from the stream;
/// contains all of the text and tool calls generated
pub choice: OneOrMany<AssistantContent>,
/// The final response from the stream; may be `None`
/// if the provider didn't yield it during the stream
pub response: Option<R>,
}
impl<R: Clone + Unpin> StreamingCompletionResponse<R> {
pub fn new(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
Self {
inner,
text: "".to_string(),
tool_calls: vec![],
choice: OneOrMany::one(AssistantContent::text("")),
response: None,
}
}
}
impl<R: Clone + Unpin> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>> {
fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
CompletionResponse {
choice: value.choice,
raw_response: value.response,
}
}
}
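// A minimal sketch (test-style; `MockUsage` is illustrative, not part of the
// crate) of the contract above, assuming a tokio test runtime: a provider
// stream that yields `FinalResponse` last leaves `response` populated once
// the outer stream has been drained, while text chunks aggregate into `choice`.
#[cfg(test)]
mod final_response_tests {
use super::*;
use futures::stream;
#[derive(Clone, Debug, PartialEq)]
struct MockUsage {
output_tokens: u32,
}
#[tokio::test]
async fn final_response_is_captured() {
let inner: StreamingResult<MockUsage> = Box::pin(stream::iter(vec![
Ok(RawStreamingChoice::Message("Hello".to_string())),
Ok(RawStreamingChoice::Message(", world".to_string())),
Ok(RawStreamingChoice::FinalResponse(MockUsage { output_tokens: 42 })),
]));
let mut streamed = StreamingCompletionResponse::new(inner);
while let Some(chunk) = streamed.next().await {
chunk.expect("the mock stream yields no errors");
}
// `streamed.choice` now holds the aggregated "Hello, world" message
assert_eq!(streamed.response, Some(MockUsage { output_tokens: 42 }));
}
}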
impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
type Item = Result<AssistantContent, CompletionError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let stream = self.get_mut();
match stream.inner.as_mut().poll_next(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => {
// This is run at the end of the inner stream to collect all tokens into
// a single unified `Message`.
let mut choice = vec![];
stream.tool_calls.iter().for_each(|tc| {
choice.push(AssistantContent::ToolCall(tc.clone()));
});
// Insert the aggregated text at the front; when there are no tool
// calls, an empty text item is inserted so the content is never empty
if choice.is_empty() || !stream.text.is_empty() {
choice.insert(0, AssistantContent::text(stream.text.clone()));
}
stream.choice = OneOrMany::many(choice)
.expect("There should be at least one assistant message");
Poll::Ready(None)
}
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
Poll::Ready(Some(Ok(choice))) => match choice {
RawStreamingChoice::Message(text) => {
// Forward the streaming token to the outer stream
// and append it to the aggregated text
stream.text.push_str(&text);
Poll::Ready(Some(Ok(AssistantContent::text(text))))
}
RawStreamingChoice::ToolCall(id, name, args) => {
// Keep track of each tool call to aggregate the final message later
// and pass it to the outer stream
stream.tool_calls.push(ToolCall {
id: id.clone(),
function: ToolFunction {
name: name.clone(),
arguments: args.clone(),
},
});
Poll::Ready(Some(Ok(AssistantContent::tool_call(id, name, args))))
}
RawStreamingChoice::FinalResponse(response) => {
// Set the final response field and return the next item in the stream
stream.response = Some(response);
stream.poll_next_unpin(cx)
}
},
}
}
}
/// Trait for high-level streaming prompt interface
pub trait StreamingPrompt: Send + Sync {
pub trait StreamingPrompt<R: Clone + Unpin>: Send + Sync {
/// Stream a simple prompt to the model
fn stream_prompt(
&self,
prompt: &str,
) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
}
/// Trait for high-level streaming chat interface
pub trait StreamingChat: Send + Sync {
pub trait StreamingChat<R: Clone + Unpin>: Send + Sync {
/// Stream a chat with history to the model
fn stream_chat(
&self,
prompt: &str,
chat_history: Vec<Message>,
) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
}
/// Trait for low-level streaming completion interface
@ -78,29 +169,35 @@ pub trait StreamingCompletion<M: StreamingCompletionModel> {
/// Trait defining a streaming completion model
pub trait StreamingCompletionModel: CompletionModel {
type StreamingResponse: Clone + Unpin;
/// Stream a completion response for the given request
fn stream(
&self,
request: CompletionRequest,
) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
) -> impl Future<
Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
>;
}
/// Helper function that streams a completion response to stdout
pub async fn stream_to_stdout<M: StreamingCompletionModel>(
agent: Agent<M>,
stream: &mut StreamingResult,
stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
) -> Result<(), std::io::Error> {
print!("Response: ");
while let Some(chunk) = stream.next().await {
match chunk {
Ok(StreamingChoice::Message(text)) => {
print!("{}", text);
Ok(AssistantContent::Text(text)) => {
print!("{}", text.text);
std::io::Write::flush(&mut std::io::stdout())?;
}
Ok(StreamingChoice::ToolCall(name, _, params)) => {
Ok(AssistantContent::ToolCall(tool_call)) => {
let res = agent
.tools
.call(&name, params.to_string())
.call(
&tool_call.function.name,
tool_call.function.arguments.to_string(),
)
.await
.map_err(|e| std::io::Error::other(e.to_string()))?;
println!("\nResult: {}", res);
@ -111,6 +208,7 @@ pub async fn stream_to_stdout<M: StreamingCompletionModel>(
}
}
}
println!(); // New line after streaming completes
Ok(())