mirror of https://github.com/0xplaygrounds/rig
Merge a78969fd9a
into 33e8fc7a65
This commit is contained in:
commit
d3f6857019
|
@ -19,5 +19,11 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
|
|
||||||
stream_to_stdout(agent, &mut stream).await?;
|
stream_to_stdout(agent, &mut stream).await?;
|
||||||
|
|
||||||
|
if let Some(response) = stream.response {
|
||||||
|
println!("Usage: {:?} tokens", response.usage.output_tokens);
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("Message: {:?}", stream.choice);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
println!("Calculate 2 - 5");
|
println!("Calculate 2 - 5");
|
||||||
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
|
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
|
||||||
stream_to_stdout(calculator_agent, &mut stream).await?;
|
stream_to_stdout(calculator_agent, &mut stream).await?;
|
||||||
|
|
||||||
|
if let Some(response) = stream.response {
|
||||||
|
println!("Usage: {:?} tokens", response.usage.output_tokens);
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("Message: {:?}", stream.choice);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
use rig::providers::cohere;
|
||||||
|
use rig::streaming::{stream_to_stdout, StreamingPrompt};
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), anyhow::Error> {
|
||||||
|
// Create streaming agent with a single context prompt
|
||||||
|
let agent = cohere::Client::from_env()
|
||||||
|
.agent(cohere::COMMAND)
|
||||||
|
.preamble("Be precise and concise.")
|
||||||
|
.temperature(0.5)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
// Stream the response and print chunks as they arrive
|
||||||
|
let mut stream = agent
|
||||||
|
.stream_prompt("When and where and what type is the next solar eclipse?")
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
stream_to_stdout(agent, &mut stream).await?;
|
||||||
|
|
||||||
|
if let Some(response) = stream.response {
|
||||||
|
println!("Usage: {:?} tokens", response.usage);
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("Message: {:?}", stream.choice);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
|
@ -0,0 +1,118 @@
|
||||||
|
use anyhow::Result;
|
||||||
|
use rig::streaming::stream_to_stdout;
|
||||||
|
use rig::{completion::ToolDefinition, providers, streaming::StreamingPrompt, tool::Tool};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct OperationArgs {
|
||||||
|
x: i32,
|
||||||
|
y: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
#[error("Math error")]
|
||||||
|
struct MathError;
|
||||||
|
|
||||||
|
#[derive(Deserialize, Serialize)]
|
||||||
|
struct Adder;
|
||||||
|
impl Tool for Adder {
|
||||||
|
const NAME: &'static str = "add";
|
||||||
|
|
||||||
|
type Error = MathError;
|
||||||
|
type Args = OperationArgs;
|
||||||
|
type Output = i32;
|
||||||
|
|
||||||
|
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||||
|
ToolDefinition {
|
||||||
|
name: "add".to_string(),
|
||||||
|
description: "Add x and y together".to_string(),
|
||||||
|
parameters: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"x": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The first number to add"
|
||||||
|
},
|
||||||
|
"y": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The second number to add"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["x", "y"]
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||||
|
let result = args.x + args.y;
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Serialize)]
|
||||||
|
struct Subtract;
|
||||||
|
impl Tool for Subtract {
|
||||||
|
const NAME: &'static str = "subtract";
|
||||||
|
|
||||||
|
type Error = MathError;
|
||||||
|
type Args = OperationArgs;
|
||||||
|
type Output = i32;
|
||||||
|
|
||||||
|
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||||
|
serde_json::from_value(json!({
|
||||||
|
"name": "subtract",
|
||||||
|
"description": "Subtract y from x (i.e.: x - y)",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"x": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The number to subtract from"
|
||||||
|
},
|
||||||
|
"y": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The number to subtract"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["x", "y"]
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.expect("Tool Definition")
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||||
|
let result = args.x - args.y;
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), anyhow::Error> {
|
||||||
|
tracing_subscriber::fmt().init();
|
||||||
|
// Create agent with a single context prompt and two tools
|
||||||
|
let calculator_agent = providers::cohere::Client::from_env()
|
||||||
|
.agent(providers::cohere::COMMAND_R)
|
||||||
|
.preamble(
|
||||||
|
"You are a calculator here to help the user perform arithmetic
|
||||||
|
operations. Use the tools provided to answer the user's question.
|
||||||
|
make your answer long, so we can test the streaming functionality,
|
||||||
|
like 20 words",
|
||||||
|
)
|
||||||
|
.max_tokens(1024)
|
||||||
|
.tool(Adder)
|
||||||
|
.tool(Subtract)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
println!("Calculate 2 - 5");
|
||||||
|
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
|
||||||
|
stream_to_stdout(calculator_agent, &mut stream).await?;
|
||||||
|
|
||||||
|
if let Some(response) = stream.response {
|
||||||
|
println!("Usage: {:?} tokens", response.usage);
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("Message: {:?}", stream.choice);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
|
@ -19,5 +19,13 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
|
|
||||||
stream_to_stdout(agent, &mut stream).await?;
|
stream_to_stdout(agent, &mut stream).await?;
|
||||||
|
|
||||||
|
if let Some(response) = stream.response {
|
||||||
|
println!(
|
||||||
|
"Usage: {:?} tokens",
|
||||||
|
response.usage_metadata.total_token_count
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("Message: {:?}", stream.choice);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -107,5 +107,15 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
println!("Calculate 2 - 5");
|
println!("Calculate 2 - 5");
|
||||||
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
|
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
|
||||||
stream_to_stdout(calculator_agent, &mut stream).await?;
|
stream_to_stdout(calculator_agent, &mut stream).await?;
|
||||||
|
|
||||||
|
if let Some(response) = stream.response {
|
||||||
|
println!(
|
||||||
|
"Usage: {:?} tokens",
|
||||||
|
response.usage_metadata.total_token_count
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("Message: {:?}", stream.choice);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,5 +17,10 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
|
|
||||||
stream_to_stdout(agent, &mut stream).await?;
|
stream_to_stdout(agent, &mut stream).await?;
|
||||||
|
|
||||||
|
if let Some(response) = stream.response {
|
||||||
|
println!("Usage: {:?} tokens", response.eval_count);
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("Message: {:?}", stream.choice);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
println!("Calculate 2 - 5");
|
println!("Calculate 2 - 5");
|
||||||
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
|
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
|
||||||
stream_to_stdout(calculator_agent, &mut stream).await?;
|
stream_to_stdout(calculator_agent, &mut stream).await?;
|
||||||
|
|
||||||
|
if let Some(response) = stream.response {
|
||||||
|
println!("Usage: {:?} tokens", response.eval_count);
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("Message: {:?}", stream.choice);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,5 +17,11 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
|
|
||||||
stream_to_stdout(agent, &mut stream).await?;
|
stream_to_stdout(agent, &mut stream).await?;
|
||||||
|
|
||||||
|
if let Some(response) = stream.response {
|
||||||
|
println!("Usage: {:?}", response.usage)
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("Message: {:?}", stream.choice);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
println!("Calculate 2 - 5");
|
println!("Calculate 2 - 5");
|
||||||
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
|
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
|
||||||
stream_to_stdout(calculator_agent, &mut stream).await?;
|
stream_to_stdout(calculator_agent, &mut stream).await?;
|
||||||
|
|
||||||
|
if let Some(response) = stream.response {
|
||||||
|
println!("Usage: {:?}", response.usage)
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("Message: {:?}", stream.choice);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,118 @@
|
||||||
|
use anyhow::Result;
|
||||||
|
use rig::streaming::stream_to_stdout;
|
||||||
|
use rig::{completion::ToolDefinition, providers, streaming::StreamingPrompt, tool::Tool};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct OperationArgs {
|
||||||
|
x: i32,
|
||||||
|
y: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
#[error("Math error")]
|
||||||
|
struct MathError;
|
||||||
|
|
||||||
|
#[derive(Deserialize, Serialize)]
|
||||||
|
struct Adder;
|
||||||
|
impl Tool for Adder {
|
||||||
|
const NAME: &'static str = "add";
|
||||||
|
|
||||||
|
type Error = MathError;
|
||||||
|
type Args = OperationArgs;
|
||||||
|
type Output = i32;
|
||||||
|
|
||||||
|
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||||
|
ToolDefinition {
|
||||||
|
name: "add".to_string(),
|
||||||
|
description: "Add x and y together".to_string(),
|
||||||
|
parameters: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"x": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The first number to add"
|
||||||
|
},
|
||||||
|
"y": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The second number to add"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["x", "y"]
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||||
|
let result = args.x + args.y;
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Serialize)]
|
||||||
|
struct Subtract;
|
||||||
|
impl Tool for Subtract {
|
||||||
|
const NAME: &'static str = "subtract";
|
||||||
|
|
||||||
|
type Error = MathError;
|
||||||
|
type Args = OperationArgs;
|
||||||
|
type Output = i32;
|
||||||
|
|
||||||
|
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||||
|
serde_json::from_value(json!({
|
||||||
|
"name": "subtract",
|
||||||
|
"description": "Subtract y from x (i.e.: x - y)",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"x": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The number to subtract from"
|
||||||
|
},
|
||||||
|
"y": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The number to subtract"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["x", "y"]
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.expect("Tool Definition")
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||||
|
let result = args.x - args.y;
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), anyhow::Error> {
|
||||||
|
tracing_subscriber::fmt().init();
|
||||||
|
// Create agent with a single context prompt and two tools
|
||||||
|
let calculator_agent = providers::openrouter::Client::from_env()
|
||||||
|
.agent(providers::openrouter::GEMINI_FLASH_2_0)
|
||||||
|
.preamble(
|
||||||
|
"You are a calculator here to help the user perform arithmetic
|
||||||
|
operations. Use the tools provided to answer the user's question.
|
||||||
|
make your answer long, so we can test the streaming functionality,
|
||||||
|
like 20 words",
|
||||||
|
)
|
||||||
|
.max_tokens(1024)
|
||||||
|
.tool(Adder)
|
||||||
|
.tool(Subtract)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
println!("Calculate 2 - 5");
|
||||||
|
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
|
||||||
|
stream_to_stdout(calculator_agent, &mut stream).await?;
|
||||||
|
|
||||||
|
if let Some(response) = stream.response {
|
||||||
|
println!("Usage: {:?}", response.usage)
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("Message: {:?}", stream.choice);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
|
@ -110,23 +110,20 @@ use std::collections::HashMap;
|
||||||
|
|
||||||
use futures::{stream, StreamExt, TryStreamExt};
|
use futures::{stream, StreamExt, TryStreamExt};
|
||||||
|
|
||||||
|
use crate::streaming::StreamingCompletionResponse;
|
||||||
|
#[cfg(feature = "mcp")]
|
||||||
|
use crate::tool::McpTool;
|
||||||
use crate::{
|
use crate::{
|
||||||
completion::{
|
completion::{
|
||||||
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
|
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
|
||||||
Message, Prompt, PromptError,
|
Message, Prompt, PromptError,
|
||||||
},
|
},
|
||||||
message::AssistantContent,
|
message::AssistantContent,
|
||||||
streaming::{
|
streaming::{StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt},
|
||||||
StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
|
|
||||||
StreamingResult,
|
|
||||||
},
|
|
||||||
tool::{Tool, ToolSet},
|
tool::{Tool, ToolSet},
|
||||||
vector_store::{VectorStoreError, VectorStoreIndexDyn},
|
vector_store::{VectorStoreError, VectorStoreIndexDyn},
|
||||||
};
|
};
|
||||||
|
|
||||||
#[cfg(feature = "mcp")]
|
|
||||||
use crate::tool::McpTool;
|
|
||||||
|
|
||||||
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
|
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
|
||||||
/// (i.e.: system prompt) and a static set of context documents and tools.
|
/// (i.e.: system prompt) and a static set of context documents and tools.
|
||||||
/// All context documents and tools are always provided to the agent when prompted.
|
/// All context documents and tools are always provided to the agent when prompted.
|
||||||
|
@ -500,18 +497,21 @@ impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
|
impl<M: StreamingCompletionModel> StreamingPrompt<M::StreamingResponse> for Agent<M> {
|
||||||
async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
|
async fn stream_prompt(
|
||||||
|
&self,
|
||||||
|
prompt: &str,
|
||||||
|
) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
|
||||||
self.stream_chat(prompt, vec![]).await
|
self.stream_chat(prompt, vec![]).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
|
impl<M: StreamingCompletionModel> StreamingChat<M::StreamingResponse> for Agent<M> {
|
||||||
async fn stream_chat(
|
async fn stream_chat(
|
||||||
&self,
|
&self,
|
||||||
prompt: &str,
|
prompt: &str,
|
||||||
chat_history: Vec<Message>,
|
chat_history: Vec<Message>,
|
||||||
) -> Result<StreamingResult, CompletionError> {
|
) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
|
||||||
self.stream_completion(prompt, chat_history)
|
self.stream_completion(prompt, chat_history)
|
||||||
.await?
|
.await?
|
||||||
.stream()
|
.stream()
|
||||||
|
|
|
@ -67,7 +67,7 @@ use std::collections::HashMap;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||||
use crate::OneOrMany;
|
use crate::OneOrMany;
|
||||||
use crate::{
|
use crate::{
|
||||||
json_utils,
|
json_utils,
|
||||||
|
@ -467,7 +467,9 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
|
||||||
|
|
||||||
impl<M: StreamingCompletionModel> CompletionRequestBuilder<M> {
|
impl<M: StreamingCompletionModel> CompletionRequestBuilder<M> {
|
||||||
/// Stream the completion request
|
/// Stream the completion request
|
||||||
pub async fn stream(self) -> Result<StreamingResult, CompletionError> {
|
pub async fn stream(
|
||||||
|
self,
|
||||||
|
) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
|
||||||
let model = self.model.clone();
|
let model = self.model.clone();
|
||||||
model.stream(self.build()).await
|
model.stream(self.build()).await
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,7 +8,8 @@ use super::decoders::sse::from_response as sse_from_response;
|
||||||
use crate::completion::{CompletionError, CompletionRequest};
|
use crate::completion::{CompletionError, CompletionRequest};
|
||||||
use crate::json_utils::merge_inplace;
|
use crate::json_utils::merge_inplace;
|
||||||
use crate::message::MessageError;
|
use crate::message::MessageError;
|
||||||
use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
|
use crate::streaming;
|
||||||
|
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult};
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
#[serde(tag = "type", rename_all = "snake_case")]
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
@ -61,7 +62,7 @@ pub struct MessageDelta {
|
||||||
pub stop_sequence: Option<String>,
|
pub stop_sequence: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize, Clone)]
|
||||||
pub struct PartialUsage {
|
pub struct PartialUsage {
|
||||||
pub output_tokens: usize,
|
pub output_tokens: usize,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
@ -75,11 +76,18 @@ struct ToolCallState {
|
||||||
input_json: String,
|
input_json: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct StreamingCompletionResponse {
|
||||||
|
pub usage: PartialUsage,
|
||||||
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
|
type StreamingResponse = StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
completion_request: CompletionRequest,
|
completion_request: CompletionRequest,
|
||||||
) -> Result<StreamingResult, CompletionError> {
|
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
|
||||||
|
{
|
||||||
let max_tokens = if let Some(tokens) = completion_request.max_tokens {
|
let max_tokens = if let Some(tokens) = completion_request.max_tokens {
|
||||||
tokens
|
tokens
|
||||||
} else if let Some(tokens) = self.default_max_tokens {
|
} else if let Some(tokens) = self.default_max_tokens {
|
||||||
|
@ -155,9 +163,10 @@ impl StreamingCompletionModel for CompletionModel {
|
||||||
// Use our SSE decoder to directly handle Server-Sent Events format
|
// Use our SSE decoder to directly handle Server-Sent Events format
|
||||||
let sse_stream = sse_from_response(response);
|
let sse_stream = sse_from_response(response);
|
||||||
|
|
||||||
Ok(Box::pin(stream! {
|
let stream: StreamingResult<Self::StreamingResponse> = Box::pin(stream! {
|
||||||
let mut current_tool_call: Option<ToolCallState> = None;
|
let mut current_tool_call: Option<ToolCallState> = None;
|
||||||
let mut sse_stream = Box::pin(sse_stream);
|
let mut sse_stream = Box::pin(sse_stream);
|
||||||
|
let mut input_tokens = 0;
|
||||||
|
|
||||||
while let Some(sse_result) = sse_stream.next().await {
|
while let Some(sse_result) = sse_stream.next().await {
|
||||||
match sse_result {
|
match sse_result {
|
||||||
|
@ -165,6 +174,24 @@ impl StreamingCompletionModel for CompletionModel {
|
||||||
// Parse the SSE data as a StreamingEvent
|
// Parse the SSE data as a StreamingEvent
|
||||||
match serde_json::from_str::<StreamingEvent>(&sse.data) {
|
match serde_json::from_str::<StreamingEvent>(&sse.data) {
|
||||||
Ok(event) => {
|
Ok(event) => {
|
||||||
|
match &event {
|
||||||
|
StreamingEvent::MessageStart { message } => {
|
||||||
|
input_tokens = message.usage.input_tokens;
|
||||||
|
},
|
||||||
|
StreamingEvent::MessageDelta { delta, usage } => {
|
||||||
|
if delta.stop_reason.is_some() {
|
||||||
|
|
||||||
|
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
|
||||||
|
usage: PartialUsage {
|
||||||
|
output_tokens: usage.output_tokens,
|
||||||
|
input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
if let Some(result) = handle_event(&event, &mut current_tool_call) {
|
if let Some(result) = handle_event(&event, &mut current_tool_call) {
|
||||||
yield result;
|
yield result;
|
||||||
}
|
}
|
||||||
|
@ -184,19 +211,21 @@ impl StreamingCompletionModel for CompletionModel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}))
|
});
|
||||||
|
|
||||||
|
Ok(streaming::StreamingCompletionResponse::new(stream))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_event(
|
fn handle_event(
|
||||||
event: &StreamingEvent,
|
event: &StreamingEvent,
|
||||||
current_tool_call: &mut Option<ToolCallState>,
|
current_tool_call: &mut Option<ToolCallState>,
|
||||||
) -> Option<Result<StreamingChoice, CompletionError>> {
|
) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
|
||||||
match event {
|
match event {
|
||||||
StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
|
StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
|
||||||
ContentDelta::TextDelta { text } => {
|
ContentDelta::TextDelta { text } => {
|
||||||
if current_tool_call.is_none() {
|
if current_tool_call.is_none() {
|
||||||
return Some(Ok(StreamingChoice::Message(text.clone())));
|
return Some(Ok(RawStreamingChoice::Message(text.clone())));
|
||||||
}
|
}
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
@ -227,7 +256,7 @@ fn handle_event(
|
||||||
&tool_call.input_json
|
&tool_call.input_json
|
||||||
};
|
};
|
||||||
match serde_json::from_str(json_str) {
|
match serde_json::from_str(json_str) {
|
||||||
Ok(json_value) => Some(Ok(StreamingChoice::ToolCall(
|
Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall(
|
||||||
tool_call.name,
|
tool_call.name,
|
||||||
tool_call.id,
|
tool_call.id,
|
||||||
json_value,
|
json_value,
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
use super::openai::{send_compatible_streaming_request, TranscriptionResponse};
|
use super::openai::{send_compatible_streaming_request, TranscriptionResponse};
|
||||||
|
|
||||||
use crate::json_utils::merge;
|
use crate::json_utils::merge;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||||
use crate::{
|
use crate::{
|
||||||
agent::AgentBuilder,
|
agent::AgentBuilder,
|
||||||
completion::{self, CompletionError, CompletionRequest},
|
completion::{self, CompletionError, CompletionRequest},
|
||||||
|
@ -570,10 +570,17 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
// Azure OpenAI Streaming API
|
// Azure OpenAI Streaming API
|
||||||
// -----------------------------------------------------
|
// -----------------------------------------------------
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
|
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||||
|
async fn stream(
|
||||||
|
&self,
|
||||||
|
request: CompletionRequest,
|
||||||
|
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(request)?;
|
let mut request = self.create_completion_request(request)?;
|
||||||
|
|
||||||
request = merge(request, json!({"stream": true}));
|
request = merge(
|
||||||
|
request,
|
||||||
|
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||||
|
);
|
||||||
|
|
||||||
let builder = self
|
let builder = self
|
||||||
.client
|
.client
|
||||||
|
|
|
@ -6,8 +6,9 @@ use crate::{
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::client::Client;
|
use super::client::Client;
|
||||||
|
use crate::completion::CompletionRequest;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::json;
|
use serde_json::{json, Value};
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct CompletionResponse {
|
pub struct CompletionResponse {
|
||||||
|
@ -419,7 +420,7 @@ impl TryFrom<Message> for message::Message {
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct CompletionModel {
|
pub struct CompletionModel {
|
||||||
client: Client,
|
pub(crate) client: Client,
|
||||||
pub model: String,
|
pub model: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -430,16 +431,11 @@ impl CompletionModel {
|
||||||
model: model.to_string(),
|
model: model.to_string(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
impl completion::CompletionModel for CompletionModel {
|
pub(crate) fn create_completion_request(
|
||||||
type Response = CompletionResponse;
|
|
||||||
|
|
||||||
#[cfg_attr(feature = "worker", worker::send)]
|
|
||||||
async fn completion(
|
|
||||||
&self,
|
&self,
|
||||||
completion_request: completion::CompletionRequest,
|
completion_request: CompletionRequest,
|
||||||
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
|
) -> Result<Value, CompletionError> {
|
||||||
let prompt = completion_request.prompt_with_context();
|
let prompt = completion_request.prompt_with_context();
|
||||||
|
|
||||||
let mut messages: Vec<message::Message> =
|
let mut messages: Vec<message::Message> =
|
||||||
|
@ -468,23 +464,29 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
"tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
|
"tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if let Some(ref params) = completion_request.additional_params {
|
||||||
|
Ok(json_utils::merge(request.clone(), params.clone()))
|
||||||
|
} else {
|
||||||
|
Ok(request)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl completion::CompletionModel for CompletionModel {
|
||||||
|
type Response = CompletionResponse;
|
||||||
|
|
||||||
|
#[cfg_attr(feature = "worker", worker::send)]
|
||||||
|
async fn completion(
|
||||||
|
&self,
|
||||||
|
completion_request: completion::CompletionRequest,
|
||||||
|
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
|
||||||
|
let request = self.create_completion_request(completion_request)?;
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
"Cohere request: {}",
|
"Cohere request: {}",
|
||||||
serde_json::to_string_pretty(&request)?
|
serde_json::to_string_pretty(&request)?
|
||||||
);
|
);
|
||||||
|
|
||||||
let response = self
|
let response = self.client.post("/v2/chat").json(&request).send().await?;
|
||||||
.client
|
|
||||||
.post("/v2/chat")
|
|
||||||
.json(
|
|
||||||
&if let Some(ref params) = completion_request.additional_params {
|
|
||||||
json_utils::merge(request.clone(), params.clone())
|
|
||||||
} else {
|
|
||||||
request.clone()
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.send()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
if response.status().is_success() {
|
if response.status().is_success() {
|
||||||
let text_response = response.text().await?;
|
let text_response = response.text().await?;
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
pub mod client;
|
pub mod client;
|
||||||
pub mod completion;
|
pub mod completion;
|
||||||
pub mod embeddings;
|
pub mod embeddings;
|
||||||
|
pub mod streaming;
|
||||||
|
|
||||||
pub use client::Client;
|
pub use client::Client;
|
||||||
pub use client::{ApiErrorResponse, ApiResponse};
|
pub use client::{ApiErrorResponse, ApiResponse};
|
||||||
|
@ -23,7 +24,7 @@ pub use embeddings::EmbeddingModel;
|
||||||
// ================================================================
|
// ================================================================
|
||||||
|
|
||||||
/// `command-r-plus` completion model
|
/// `command-r-plus` completion model
|
||||||
pub const COMMAND_R_PLUS: &str = "comman-r-plus";
|
pub const COMMAND_R_PLUS: &str = "command-r-plus";
|
||||||
/// `command-r` completion model
|
/// `command-r` completion model
|
||||||
pub const COMMAND_R: &str = "command-r";
|
pub const COMMAND_R: &str = "command-r";
|
||||||
/// `command` completion model
|
/// `command` completion model
|
||||||
|
|
|
@ -0,0 +1,188 @@
|
||||||
|
use crate::completion::{CompletionError, CompletionRequest};
|
||||||
|
use crate::providers::cohere::completion::Usage;
|
||||||
|
use crate::providers::cohere::CompletionModel;
|
||||||
|
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
|
||||||
|
use crate::{json_utils, streaming};
|
||||||
|
use async_stream::stream;
|
||||||
|
use futures::StreamExt;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
#[serde(rename_all = "kebab-case", tag = "type")]
|
||||||
|
enum StreamingEvent {
|
||||||
|
MessageStart,
|
||||||
|
ContentStart,
|
||||||
|
ContentDelta { delta: Option<Delta> },
|
||||||
|
ContentEnd,
|
||||||
|
ToolPlan,
|
||||||
|
ToolCallStart { delta: Option<Delta> },
|
||||||
|
ToolCallDelta { delta: Option<Delta> },
|
||||||
|
ToolCallEnd,
|
||||||
|
MessageEnd { delta: Option<MessageEndDelta> },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct MessageContentDelta {
|
||||||
|
text: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct MessageToolFunctionDelta {
|
||||||
|
name: Option<String>,
|
||||||
|
arguments: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct MessageToolCallDelta {
|
||||||
|
id: Option<String>,
|
||||||
|
function: Option<MessageToolFunctionDelta>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The `message` object of a streamed delta: either text content or tool-call data.
#[derive(Debug, Deserialize)]
|
||||||
|
struct MessageDelta {
|
||||||
|
/// Present on content events.
content: Option<MessageContentDelta>,
|
||||||
|
/// Present on tool-call events; deserialized as a single object, not a list.
tool_calls: Option<MessageToolCallDelta>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The `delta` payload shared by content and tool-call events.
#[derive(Debug, Deserialize)]
|
||||||
|
struct Delta {
|
||||||
|
/// Incremental message data; events without it are skipped by the consumer.
message: Option<MessageDelta>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The `delta` payload of the `MessageEnd` event.
#[derive(Debug, Deserialize)]
|
||||||
|
struct MessageEndDelta {
|
||||||
|
/// Token usage for the finished message, when the provider includes it.
usage: Option<Usage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct StreamingCompletionResponse {
|
||||||
|
pub usage: Option<Usage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
|
type StreamingResponse = StreamingCompletionResponse;
|
||||||
|
|
||||||
|
async fn stream(
|
||||||
|
&self,
|
||||||
|
request: CompletionRequest,
|
||||||
|
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
|
||||||
|
{
|
||||||
|
let request = self.create_completion_request(request)?;
|
||||||
|
let request = json_utils::merge(request, json!({"stream": true}));
|
||||||
|
|
||||||
|
tracing::debug!(
|
||||||
|
"Cohere request: {}",
|
||||||
|
serde_json::to_string_pretty(&request)?
|
||||||
|
);
|
||||||
|
|
||||||
|
let response = self.client.post("/v2/chat").json(&request).send().await?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
return Err(CompletionError::ProviderError(format!(
|
||||||
|
"{}: {}",
|
||||||
|
response.status(),
|
||||||
|
response.text().await?
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let stream = Box::pin(stream! {
|
||||||
|
let mut stream = response.bytes_stream();
|
||||||
|
let mut current_tool_call: Option<(String, String, String)> = None;
|
||||||
|
|
||||||
|
while let Some(chunk_result) = stream.next().await {
|
||||||
|
let chunk = match chunk_result {
|
||||||
|
Ok(c) => c,
|
||||||
|
Err(e) => {
|
||||||
|
yield Err(CompletionError::from(e));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let text = match String::from_utf8(chunk.to_vec()) {
|
||||||
|
Ok(t) => t,
|
||||||
|
Err(e) => {
|
||||||
|
yield Err(CompletionError::ResponseError(e.to_string()));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
for line in text.lines() {
|
||||||
|
|
||||||
|
let Some(line) = line.strip_prefix("data: ") else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
let event = {
|
||||||
|
let result = serde_json::from_str::<StreamingEvent>(line);
|
||||||
|
|
||||||
|
let Ok(event) = result else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
event
|
||||||
|
};
|
||||||
|
|
||||||
|
match event {
|
||||||
|
StreamingEvent::ContentDelta { delta: Some(delta) } => {
|
||||||
|
let Some(message) = &delta.message else { continue; };
|
||||||
|
let Some(content) = &message.content else { continue; };
|
||||||
|
let Some(text) = &content.text else { continue; };
|
||||||
|
|
||||||
|
yield Ok(RawStreamingChoice::Message(text.clone()));
|
||||||
|
},
|
||||||
|
StreamingEvent::MessageEnd {delta: Some(delta)} => {
|
||||||
|
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
|
||||||
|
usage: delta.usage.clone()
|
||||||
|
}));
|
||||||
|
},
|
||||||
|
StreamingEvent::ToolCallStart { delta: Some(delta)} => {
|
||||||
|
// Skip the delta if there's any missing information,
|
||||||
|
// though this *should* all be present
|
||||||
|
let Some(message) = &delta.message else { continue; };
|
||||||
|
let Some(tool_calls) = &message.tool_calls else { continue; };
|
||||||
|
let Some(id) = tool_calls.id.clone() else { continue; };
|
||||||
|
let Some(function) = &tool_calls.function else { continue; };
|
||||||
|
let Some(name) = function.name.clone() else { continue; };
|
||||||
|
let Some(arguments) = function.arguments.clone() else { continue; };
|
||||||
|
|
||||||
|
current_tool_call = Some((id, name, arguments));
|
||||||
|
},
|
||||||
|
StreamingEvent::ToolCallDelta { delta: Some(delta)} => {
|
||||||
|
// Skip the delta if there's any missing information,
|
||||||
|
// though this *should* all be present
|
||||||
|
let Some(message) = &delta.message else { continue; };
|
||||||
|
let Some(tool_calls) = &message.tool_calls else { continue; };
|
||||||
|
let Some(function) = &tool_calls.function else { continue; };
|
||||||
|
let Some(arguments) = function.arguments.clone() else { continue; };
|
||||||
|
|
||||||
|
if let Some(tc) = current_tool_call.clone() {
|
||||||
|
current_tool_call = Some((
|
||||||
|
tc.0,
|
||||||
|
tc.1,
|
||||||
|
format!("{}{}", tc.2, arguments)
|
||||||
|
));
|
||||||
|
};
|
||||||
|
},
|
||||||
|
StreamingEvent::ToolCallEnd => {
|
||||||
|
let Some(tc) = current_tool_call.clone() else { continue; };
|
||||||
|
|
||||||
|
let Ok(args) = serde_json::from_str(&tc.2) else { continue; };
|
||||||
|
|
||||||
|
yield Ok(RawStreamingChoice::ToolCall(
|
||||||
|
tc.0,
|
||||||
|
tc.1,
|
||||||
|
args
|
||||||
|
));
|
||||||
|
|
||||||
|
current_tool_call = None;
|
||||||
|
},
|
||||||
|
_ => {}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(streaming::StreamingCompletionResponse::new(stream))
|
||||||
|
}
|
||||||
|
}
|
|
@ -10,8 +10,9 @@
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
use crate::json_utils::merge;
|
use crate::json_utils::merge;
|
||||||
|
use crate::providers::openai;
|
||||||
use crate::providers::openai::send_compatible_streaming_request;
|
use crate::providers::openai::send_compatible_streaming_request;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||||
use crate::{
|
use crate::{
|
||||||
completion::{self, CompletionError, CompletionModel, CompletionRequest},
|
completion::{self, CompletionError, CompletionModel, CompletionRequest},
|
||||||
extractor::ExtractorBuilder,
|
extractor::ExtractorBuilder,
|
||||||
|
@ -463,13 +464,17 @@ impl CompletionModel for DeepSeekCompletionModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for DeepSeekCompletionModel {
|
impl StreamingCompletionModel for DeepSeekCompletionModel {
|
||||||
|
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
completion_request: CompletionRequest,
|
completion_request: CompletionRequest,
|
||||||
) -> Result<StreamingResult, CompletionError> {
|
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(completion_request)?;
|
let mut request = self.create_completion_request(completion_request)?;
|
||||||
|
|
||||||
request = merge(request, json!({"stream": true}));
|
request = merge(
|
||||||
|
request,
|
||||||
|
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||||
|
);
|
||||||
|
|
||||||
let builder = self.client.post("/v1/chat/completions").json(&request);
|
let builder = self.client.post("/v1/chat/completions").json(&request);
|
||||||
send_compatible_streaming_request(builder).await
|
send_compatible_streaming_request(builder).await
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
use super::openai;
|
use super::openai;
|
||||||
use crate::json_utils::merge;
|
use crate::json_utils::merge;
|
||||||
use crate::providers::openai::send_compatible_streaming_request;
|
use crate::providers::openai::send_compatible_streaming_request;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||||
use crate::{
|
use crate::{
|
||||||
agent::AgentBuilder,
|
agent::AgentBuilder,
|
||||||
completion::{self, CompletionError, CompletionRequest},
|
completion::{self, CompletionError, CompletionRequest},
|
||||||
|
@ -495,10 +495,18 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
|
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||||
|
|
||||||
|
async fn stream(
|
||||||
|
&self,
|
||||||
|
request: CompletionRequest,
|
||||||
|
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(request)?;
|
let mut request = self.create_completion_request(request)?;
|
||||||
|
|
||||||
request = merge(request, json!({"stream": true}));
|
request = merge(
|
||||||
|
request,
|
||||||
|
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||||
|
);
|
||||||
|
|
||||||
let builder = self.client.post("/chat/completions").json(&request);
|
let builder = self.client.post("/chat/completions").json(&request);
|
||||||
|
|
||||||
|
|
|
@ -609,7 +609,7 @@ pub mod gemini_api_types {
|
||||||
HarmCategoryCivicIntegrity,
|
HarmCategoryCivicIntegrity,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize, Clone, Default)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct UsageMetadata {
|
pub struct UsageMetadata {
|
||||||
pub prompt_token_count: i32,
|
pub prompt_token_count: i32,
|
||||||
|
|
|
@ -2,12 +2,17 @@ use async_stream::stream;
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
use super::completion::{create_request_body, gemini_api_types::ContentCandidate, CompletionModel};
|
||||||
use crate::{
|
use crate::{
|
||||||
completion::{CompletionError, CompletionRequest},
|
completion::{CompletionError, CompletionRequest},
|
||||||
streaming::{self, StreamingCompletionModel, StreamingResult},
|
streaming::{self, StreamingCompletionModel},
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::completion::{create_request_body, gemini_api_types::ContentCandidate, CompletionModel};
|
/// Usage information kept from a streamed response: only the total token count.
///
/// Field names are read in camelCase, so `total_token_count` is taken from
/// `totalTokenCount`.
#[derive(Debug, Deserialize, Default, Clone)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct PartialUsage {
|
||||||
|
/// Total number of tokens reported for the response.
pub total_token_count: i32,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
|
@ -15,13 +20,21 @@ pub struct StreamGenerateContentResponse {
|
||||||
/// Candidate responses from the model.
|
/// Candidate responses from the model.
|
||||||
pub candidates: Vec<ContentCandidate>,
|
pub candidates: Vec<ContentCandidate>,
|
||||||
pub model_version: Option<String>,
|
pub model_version: Option<String>,
|
||||||
|
pub usage_metadata: Option<PartialUsage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct StreamingCompletionResponse {
|
||||||
|
pub usage_metadata: PartialUsage,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
|
type StreamingResponse = StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
completion_request: CompletionRequest,
|
completion_request: CompletionRequest,
|
||||||
) -> Result<StreamingResult, CompletionError> {
|
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
|
||||||
|
{
|
||||||
let request = create_request_body(completion_request)?;
|
let request = create_request_body(completion_request)?;
|
||||||
|
|
||||||
let response = self
|
let response = self
|
||||||
|
@ -42,7 +55,7 @@ impl StreamingCompletionModel for CompletionModel {
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Box::pin(stream! {
|
let stream = Box::pin(stream! {
|
||||||
let mut stream = response.bytes_stream();
|
let mut stream = response.bytes_stream();
|
||||||
|
|
||||||
while let Some(chunk_result) = stream.next().await {
|
while let Some(chunk_result) = stream.next().await {
|
||||||
|
@ -74,13 +87,23 @@ impl StreamingCompletionModel for CompletionModel {
|
||||||
|
|
||||||
match choice.content.parts.first() {
|
match choice.content.parts.first() {
|
||||||
super::completion::gemini_api_types::Part::Text(text)
|
super::completion::gemini_api_types::Part::Text(text)
|
||||||
=> yield Ok(streaming::StreamingChoice::Message(text)),
|
=> yield Ok(streaming::RawStreamingChoice::Message(text)),
|
||||||
super::completion::gemini_api_types::Part::FunctionCall(function_call)
|
super::completion::gemini_api_types::Part::FunctionCall(function_call)
|
||||||
=> yield Ok(streaming::StreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
|
=> yield Ok(streaming::RawStreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
|
||||||
_ => panic!("Unsupported response type with streaming.")
|
_ => panic!("Unsupported response type with streaming.")
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if choice.finish_reason.is_some() {
|
||||||
|
yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
|
||||||
|
usage_metadata: PartialUsage {
|
||||||
|
total_token_count: data.usage_metadata.unwrap().total_token_count,
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}))
|
});
|
||||||
|
|
||||||
|
Ok(streaming::StreamingCompletionResponse::new(stream))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,7 +10,8 @@
|
||||||
//! ```
|
//! ```
|
||||||
use super::openai::{send_compatible_streaming_request, CompletionResponse, TranscriptionResponse};
|
use super::openai::{send_compatible_streaming_request, CompletionResponse, TranscriptionResponse};
|
||||||
use crate::json_utils::merge;
|
use crate::json_utils::merge;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
use crate::providers::openai;
|
||||||
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||||
use crate::{
|
use crate::{
|
||||||
agent::AgentBuilder,
|
agent::AgentBuilder,
|
||||||
completion::{self, CompletionError, CompletionRequest},
|
completion::{self, CompletionError, CompletionRequest},
|
||||||
|
@ -363,10 +364,17 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
|
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||||
|
async fn stream(
|
||||||
|
&self,
|
||||||
|
request: CompletionRequest,
|
||||||
|
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(request)?;
|
let mut request = self.create_completion_request(request)?;
|
||||||
|
|
||||||
request = merge(request, json!({"stream": true}));
|
request = merge(
|
||||||
|
request,
|
||||||
|
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||||
|
);
|
||||||
|
|
||||||
let builder = self.client.post("/chat/completions").json(&request);
|
let builder = self.client.post("/chat/completions").json(&request);
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
use super::completion::CompletionModel;
|
use super::completion::CompletionModel;
|
||||||
use crate::completion::{CompletionError, CompletionRequest};
|
use crate::completion::{CompletionError, CompletionRequest};
|
||||||
use crate::json_utils;
|
|
||||||
use crate::json_utils::merge_inplace;
|
use crate::json_utils::merge_inplace;
|
||||||
use crate::providers::openai::send_compatible_streaming_request;
|
use crate::providers::openai::{send_compatible_streaming_request, StreamingCompletionResponse};
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
use crate::streaming::StreamingCompletionModel;
|
||||||
|
use crate::{json_utils, streaming};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Value};
|
||||||
use std::convert::Infallible;
|
use std::convert::Infallible;
|
||||||
|
@ -55,14 +55,19 @@ struct CompletionChunk {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
|
type StreamingResponse = StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
completion_request: CompletionRequest,
|
completion_request: CompletionRequest,
|
||||||
) -> Result<StreamingResult, CompletionError> {
|
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
|
||||||
|
{
|
||||||
let mut request = self.create_request_body(&completion_request)?;
|
let mut request = self.create_request_body(&completion_request)?;
|
||||||
|
|
||||||
// Enable streaming
|
// Enable streaming
|
||||||
merge_inplace(&mut request, json!({"stream": true}));
|
merge_inplace(
|
||||||
|
&mut request,
|
||||||
|
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||||
|
);
|
||||||
|
|
||||||
if let Some(ref params) = completion_request.additional_params {
|
if let Some(ref params) = completion_request.additional_params {
|
||||||
merge_inplace(&mut request, params.clone());
|
merge_inplace(&mut request, params.clone());
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
use super::openai::{send_compatible_streaming_request, AssistantContent};
|
use super::openai::{send_compatible_streaming_request, AssistantContent};
|
||||||
|
|
||||||
use crate::json_utils::merge_inplace;
|
use crate::json_utils::merge_inplace;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||||
use crate::{
|
use crate::{
|
||||||
agent::AgentBuilder,
|
agent::AgentBuilder,
|
||||||
completion::{self, CompletionError, CompletionRequest},
|
completion::{self, CompletionError, CompletionRequest},
|
||||||
|
@ -390,13 +390,17 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
|
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
completion_request: CompletionRequest,
|
completion_request: CompletionRequest,
|
||||||
) -> Result<StreamingResult, CompletionError> {
|
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(completion_request)?;
|
let mut request = self.create_completion_request(completion_request)?;
|
||||||
|
|
||||||
merge_inplace(&mut request, json!({"stream": true}));
|
merge_inplace(
|
||||||
|
&mut request,
|
||||||
|
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||||
|
);
|
||||||
|
|
||||||
let builder = self.client.post("/chat/completions").json(&request);
|
let builder = self.client.post("/chat/completions").json(&request);
|
||||||
|
|
||||||
|
@ -526,8 +530,10 @@ mod image_generation {
|
||||||
// ======================================
|
// ======================================
|
||||||
// Hyperbolic Audio Generation API
|
// Hyperbolic Audio Generation API
|
||||||
// ======================================
|
// ======================================
|
||||||
|
use crate::providers::openai;
|
||||||
#[cfg(feature = "audio")]
|
#[cfg(feature = "audio")]
|
||||||
pub use audio_generation::*;
|
pub use audio_generation::*;
|
||||||
|
|
||||||
#[cfg(feature = "audio")]
|
#[cfg(feature = "audio")]
|
||||||
mod audio_generation {
|
mod audio_generation {
|
||||||
use super::{ApiResponse, Client};
|
use super::{ApiResponse, Client};
|
||||||
|
|
|
@ -8,8 +8,9 @@
|
||||||
//!
|
//!
|
||||||
//! ```
|
//! ```
|
||||||
use crate::json_utils::merge;
|
use crate::json_utils::merge;
|
||||||
|
use crate::providers::openai;
|
||||||
use crate::providers::openai::send_compatible_streaming_request;
|
use crate::providers::openai::send_compatible_streaming_request;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||||
use crate::{
|
use crate::{
|
||||||
agent::AgentBuilder,
|
agent::AgentBuilder,
|
||||||
completion::{self, CompletionError, CompletionRequest},
|
completion::{self, CompletionError, CompletionRequest},
|
||||||
|
@ -347,10 +348,11 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
|
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
completion_request: CompletionRequest,
|
completion_request: CompletionRequest,
|
||||||
) -> Result<StreamingResult, CompletionError> {
|
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(completion_request)?;
|
let mut request = self.create_completion_request(completion_request)?;
|
||||||
|
|
||||||
request = merge(request, json!({"stream": true}));
|
request = merge(request, json!({"stream": true}));
|
||||||
|
|
|
@ -11,7 +11,7 @@
|
||||||
|
|
||||||
use crate::json_utils::merge;
|
use crate::json_utils::merge;
|
||||||
use crate::providers::openai::send_compatible_streaming_request;
|
use crate::providers::openai::send_compatible_streaming_request;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||||
use crate::{
|
use crate::{
|
||||||
agent::AgentBuilder,
|
agent::AgentBuilder,
|
||||||
completion::{self, CompletionError, CompletionRequest},
|
completion::{self, CompletionError, CompletionRequest},
|
||||||
|
@ -228,10 +228,18 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
|
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||||
|
|
||||||
|
async fn stream(
|
||||||
|
&self,
|
||||||
|
request: CompletionRequest,
|
||||||
|
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(request)?;
|
let mut request = self.create_completion_request(request)?;
|
||||||
|
|
||||||
request = merge(request, json!({"stream": true}));
|
request = merge(
|
||||||
|
request,
|
||||||
|
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||||
|
);
|
||||||
|
|
||||||
let builder = self.client.post("/chat/completions").json(&request);
|
let builder = self.client.post("/chat/completions").json(&request);
|
||||||
|
|
||||||
|
|
|
@ -39,7 +39,7 @@
|
||||||
//! let extractor = client.extractor::<serde_json::Value>("llama3.2");
|
//! let extractor = client.extractor::<serde_json::Value>("llama3.2");
|
||||||
//! ```
|
//! ```
|
||||||
use crate::json_utils::merge_inplace;
|
use crate::json_utils::merge_inplace;
|
||||||
use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
|
||||||
use crate::{
|
use crate::{
|
||||||
agent::AgentBuilder,
|
agent::AgentBuilder,
|
||||||
completion::{self, CompletionError, CompletionRequest},
|
completion::{self, CompletionError, CompletionRequest},
|
||||||
|
@ -47,7 +47,7 @@ use crate::{
|
||||||
extractor::ExtractorBuilder,
|
extractor::ExtractorBuilder,
|
||||||
json_utils, message,
|
json_utils, message,
|
||||||
message::{ImageDetail, Text},
|
message::{ImageDetail, Text},
|
||||||
Embed, OneOrMany,
|
streaming, Embed, OneOrMany,
|
||||||
};
|
};
|
||||||
use async_stream::stream;
|
use async_stream::stream;
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
|
@ -405,8 +405,25 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Final response emitted once an Ollama stream has finished.
///
/// Every field is copied from the streamed chunk whose `done` flag is set;
/// each is optional and stays `None` when that chunk does not carry it.
#[derive(Clone, Debug)]
pub struct StreamingCompletionResponse {
    /// Reason the generation finished, as reported by the server.
    pub done_reason: Option<String>,
    /// Total duration reported by the server.
    pub total_duration: Option<u64>,
    /// Load duration reported by the server.
    pub load_duration: Option<u64>,
    /// Prompt evaluation count reported by the server.
    pub prompt_eval_count: Option<u64>,
    /// Prompt evaluation duration reported by the server.
    pub prompt_eval_duration: Option<u64>,
    /// Evaluation count reported by the server.
    pub eval_count: Option<u64>,
    /// Evaluation duration reported by the server.
    pub eval_duration: Option<u64>,
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
|
type StreamingResponse = StreamingCompletionResponse;
|
||||||
|
|
||||||
|
async fn stream(
|
||||||
|
&self,
|
||||||
|
request: CompletionRequest,
|
||||||
|
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
|
||||||
|
{
|
||||||
let mut request_payload = self.create_completion_request(request)?;
|
let mut request_payload = self.create_completion_request(request)?;
|
||||||
merge_inplace(&mut request_payload, json!({"stream": true}));
|
merge_inplace(&mut request_payload, json!({"stream": true}));
|
||||||
|
|
||||||
|
@ -426,7 +443,7 @@ impl StreamingCompletionModel for CompletionModel {
|
||||||
return Err(CompletionError::ProviderError(err_text));
|
return Err(CompletionError::ProviderError(err_text));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Box::pin(stream! {
|
let stream = Box::pin(stream! {
|
||||||
let mut stream = response.bytes_stream();
|
let mut stream = response.bytes_stream();
|
||||||
while let Some(chunk_result) = stream.next().await {
|
while let Some(chunk_result) = stream.next().await {
|
||||||
let chunk = match chunk_result {
|
let chunk = match chunk_result {
|
||||||
|
@ -456,22 +473,36 @@ impl StreamingCompletionModel for CompletionModel {
|
||||||
match response.message {
|
match response.message {
|
||||||
Message::Assistant{ content, tool_calls, .. } => {
|
Message::Assistant{ content, tool_calls, .. } => {
|
||||||
if !content.is_empty() {
|
if !content.is_empty() {
|
||||||
yield Ok(StreamingChoice::Message(content))
|
yield Ok(RawStreamingChoice::Message(content))
|
||||||
}
|
}
|
||||||
|
|
||||||
for tool_call in tool_calls.iter() {
|
for tool_call in tool_calls.iter() {
|
||||||
let function = tool_call.function.clone();
|
let function = tool_call.function.clone();
|
||||||
|
|
||||||
yield Ok(StreamingChoice::ToolCall(function.name, "".to_string(), function.arguments));
|
yield Ok(RawStreamingChoice::ToolCall(function.name, "".to_string(), function.arguments));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if response.done {
|
||||||
|
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
|
||||||
|
total_duration: response.total_duration,
|
||||||
|
load_duration: response.load_duration,
|
||||||
|
prompt_eval_count: response.prompt_eval_count,
|
||||||
|
prompt_eval_duration: response.prompt_eval_duration,
|
||||||
|
eval_count: response.eval_count,
|
||||||
|
eval_duration: response.eval_duration,
|
||||||
|
done_reason: response.done_reason,
|
||||||
|
}));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}))
|
});
|
||||||
|
|
||||||
|
Ok(streaming::StreamingCompletionResponse::new(stream))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,14 +2,16 @@ use super::completion::CompletionModel;
|
||||||
use crate::completion::{CompletionError, CompletionRequest};
|
use crate::completion::{CompletionError, CompletionRequest};
|
||||||
use crate::json_utils;
|
use crate::json_utils;
|
||||||
use crate::json_utils::merge;
|
use crate::json_utils::merge;
|
||||||
|
use crate::providers::openai::Usage;
|
||||||
use crate::streaming;
|
use crate::streaming;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
|
||||||
use async_stream::stream;
|
use async_stream::stream;
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use reqwest::RequestBuilder;
|
use reqwest::RequestBuilder;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use tracing::debug;
|
||||||
|
|
||||||
// ================================================================
|
// ================================================================
|
||||||
// OpenAI Completion Streaming API
|
// OpenAI Completion Streaming API
|
||||||
|
@ -25,10 +27,11 @@ pub struct StreamingFunction {
|
||||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
pub struct StreamingToolCall {
|
pub struct StreamingToolCall {
|
||||||
pub index: usize,
|
pub index: usize,
|
||||||
|
pub id: Option<String>,
|
||||||
pub function: StreamingFunction,
|
pub function: StreamingFunction,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize, Debug)]
|
||||||
struct StreamingDelta {
|
struct StreamingDelta {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
content: Option<String>,
|
content: Option<String>,
|
||||||
|
@ -36,23 +39,34 @@ struct StreamingDelta {
|
||||||
tool_calls: Vec<StreamingToolCall>,
|
tool_calls: Vec<StreamingToolCall>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize, Debug)]
|
||||||
struct StreamingChoice {
|
struct StreamingChoice {
|
||||||
delta: StreamingDelta,
|
delta: StreamingDelta,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize, Debug)]
|
||||||
struct StreamingCompletionResponse {
|
struct StreamingCompletionChunk {
|
||||||
choices: Vec<StreamingChoice>,
|
choices: Vec<StreamingChoice>,
|
||||||
|
usage: Option<Usage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct StreamingCompletionResponse {
|
||||||
|
pub usage: Usage,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
|
type StreamingResponse = StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
completion_request: CompletionRequest,
|
completion_request: CompletionRequest,
|
||||||
) -> Result<StreamingResult, CompletionError> {
|
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
|
||||||
|
{
|
||||||
let mut request = self.create_completion_request(completion_request)?;
|
let mut request = self.create_completion_request(completion_request)?;
|
||||||
request = merge(request, json!({"stream": true}));
|
request = merge(
|
||||||
|
request,
|
||||||
|
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||||
|
);
|
||||||
|
|
||||||
let builder = self.client.post("/chat/completions").json(&request);
|
let builder = self.client.post("/chat/completions").json(&request);
|
||||||
send_compatible_streaming_request(builder).await
|
send_compatible_streaming_request(builder).await
|
||||||
|
@ -61,7 +75,7 @@ impl StreamingCompletionModel for CompletionModel {
|
||||||
|
|
||||||
pub async fn send_compatible_streaming_request(
|
pub async fn send_compatible_streaming_request(
|
||||||
request_builder: RequestBuilder,
|
request_builder: RequestBuilder,
|
||||||
) -> Result<StreamingResult, CompletionError> {
|
) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError> {
|
||||||
let response = request_builder.send().await?;
|
let response = request_builder.send().await?;
|
||||||
|
|
||||||
if !response.status().is_success() {
|
if !response.status().is_success() {
|
||||||
|
@ -73,11 +87,16 @@ pub async fn send_compatible_streaming_request(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle OpenAI Compatible SSE chunks
|
// Handle OpenAI Compatible SSE chunks
|
||||||
Ok(Box::pin(stream! {
|
let inner = Box::pin(stream! {
|
||||||
let mut stream = response.bytes_stream();
|
let mut stream = response.bytes_stream();
|
||||||
|
|
||||||
|
let mut final_usage = Usage {
|
||||||
|
prompt_tokens: 0,
|
||||||
|
total_tokens: 0
|
||||||
|
};
|
||||||
|
|
||||||
let mut partial_data = None;
|
let mut partial_data = None;
|
||||||
let mut calls: HashMap<usize, (String, String)> = HashMap::new();
|
let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
|
||||||
|
|
||||||
while let Some(chunk_result) = stream.next().await {
|
while let Some(chunk_result) = stream.next().await {
|
||||||
let chunk = match chunk_result {
|
let chunk = match chunk_result {
|
||||||
|
@ -100,8 +119,6 @@ pub async fn send_compatible_streaming_request(
|
||||||
for line in text.lines() {
|
for line in text.lines() {
|
||||||
let mut line = line.to_string();
|
let mut line = line.to_string();
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// If there was a remaining part, concat with current line
|
// If there was a remaining part, concat with current line
|
||||||
if partial_data.is_some() {
|
if partial_data.is_some() {
|
||||||
line = format!("{}{}", partial_data.unwrap(), line);
|
line = format!("{}{}", partial_data.unwrap(), line);
|
||||||
|
@ -121,64 +138,85 @@ pub async fn send_compatible_streaming_request(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let data = serde_json::from_str::<StreamingCompletionResponse>(&line);
|
let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
|
||||||
|
|
||||||
let Ok(data) = data else {
|
let Ok(data) = data else {
|
||||||
|
let err = data.unwrap_err();
|
||||||
|
debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
|
|
||||||
let choice = data.choices.first().expect("Should have at least one choice");
|
|
||||||
|
|
||||||
let delta = &choice.delta;
|
if let Some(choice) = data.choices.first() {
|
||||||
|
|
||||||
if !delta.tool_calls.is_empty() {
|
let delta = &choice.delta;
|
||||||
for tool_call in &delta.tool_calls {
|
|
||||||
let function = tool_call.function.clone();
|
|
||||||
|
|
||||||
// Start of tool call
|
if !delta.tool_calls.is_empty() {
|
||||||
// name: Some(String)
|
for tool_call in &delta.tool_calls {
|
||||||
// arguments: None
|
let function = tool_call.function.clone();
|
||||||
if function.name.is_some() && function.arguments.is_empty() {
|
// Start of tool call
|
||||||
calls.insert(tool_call.index, (function.name.clone().unwrap(), "".to_string()));
|
// name: Some(String)
|
||||||
|
// arguments: None
|
||||||
|
if function.name.is_some() && function.arguments.is_empty() {
|
||||||
|
let id = tool_call.id.clone().unwrap_or("".to_string());
|
||||||
|
|
||||||
|
calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
|
||||||
|
}
|
||||||
|
// Part of tool call
|
||||||
|
// name: None
|
||||||
|
// arguments: Some(String)
|
||||||
|
else if function.name.is_none() && !function.arguments.is_empty() {
|
||||||
|
let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
|
||||||
|
debug!("Partial tool call received but tool call was never started.");
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
let new_arguments = &tool_call.function.arguments;
|
||||||
|
let arguments = format!("{}{}", arguments, new_arguments);
|
||||||
|
|
||||||
|
calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
|
||||||
|
}
|
||||||
|
// Entire tool call
|
||||||
|
else {
|
||||||
|
let id = tool_call.id.clone().unwrap_or("".to_string());
|
||||||
|
let name = function.name.expect("function name should be present for complete tool call");
|
||||||
|
let arguments = function.arguments;
|
||||||
|
let Ok(arguments) = serde_json::from_str(&arguments) else {
|
||||||
|
debug!("Couldn't serialize '{}' as a json value", arguments);
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
yield Ok(streaming::RawStreamingChoice::ToolCall(id, name, arguments))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// Part of tool call
|
}
|
||||||
// name: None
|
|
||||||
// arguments: Some(String)
|
|
||||||
else if function.name.is_none() && !function.arguments.is_empty() {
|
|
||||||
let Some((name, arguments)) = calls.get(&tool_call.index) else {
|
|
||||||
continue;
|
|
||||||
};
|
|
||||||
|
|
||||||
let new_arguments = &tool_call.function.arguments;
|
if let Some(content) = &choice.delta.content {
|
||||||
let arguments = format!("{}{}", arguments, new_arguments);
|
yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
|
||||||
|
|
||||||
calls.insert(tool_call.index, (name.clone(), arguments));
|
|
||||||
}
|
|
||||||
// Entire tool call
|
|
||||||
else {
|
|
||||||
let name = function.name.unwrap();
|
|
||||||
let arguments = function.arguments;
|
|
||||||
let Ok(arguments) = serde_json::from_str(&arguments) else {
|
|
||||||
continue;
|
|
||||||
};
|
|
||||||
|
|
||||||
yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(content) = &choice.delta.content {
|
|
||||||
yield Ok(streaming::StreamingChoice::Message(content.clone()))
|
if let Some(usage) = data.usage {
|
||||||
|
final_usage = usage.clone();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (_, (name, arguments)) in calls {
|
for (_, (id, name, arguments)) in calls {
|
||||||
let Ok(arguments) = serde_json::from_str(&arguments) else {
|
let Ok(arguments) = serde_json::from_str(&arguments) else {
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
|
|
||||||
yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
|
println!("{id} {name}");
|
||||||
|
|
||||||
|
yield Ok(RawStreamingChoice::ToolCall(id, name, arguments))
|
||||||
}
|
}
|
||||||
}))
|
|
||||||
|
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
|
||||||
|
usage: final_usage.clone()
|
||||||
|
}))
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(streaming::StreamingCompletionResponse::new(inner))
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,125 @@
|
||||||
|
use crate::{agent::AgentBuilder, extractor::ExtractorBuilder};
|
||||||
|
use schemars::JsonSchema;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use super::completion::CompletionModel;
|
||||||
|
|
||||||
|
// ================================================================
|
||||||
|
// Main openrouter Client
|
||||||
|
// ================================================================
|
||||||
|
const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Client {
|
||||||
|
base_url: String,
|
||||||
|
http_client: reqwest::Client,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Client {
|
||||||
|
/// Create a new OpenRouter client with the given API key.
|
||||||
|
pub fn new(api_key: &str) -> Self {
|
||||||
|
Self::from_url(api_key, OPENROUTER_API_BASE_URL)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new OpenRouter client with the given API key and base API URL.
|
||||||
|
pub fn from_url(api_key: &str, base_url: &str) -> Self {
|
||||||
|
Self {
|
||||||
|
base_url: base_url.to_string(),
|
||||||
|
http_client: reqwest::Client::builder()
|
||||||
|
.default_headers({
|
||||||
|
let mut headers = reqwest::header::HeaderMap::new();
|
||||||
|
headers.insert(
|
||||||
|
"Authorization",
|
||||||
|
format!("Bearer {}", api_key)
|
||||||
|
.parse()
|
||||||
|
.expect("Bearer token should parse"),
|
||||||
|
);
|
||||||
|
headers
|
||||||
|
})
|
||||||
|
.build()
|
||||||
|
.expect("OpenRouter reqwest client should build"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new OpenRouter client from the `OPENROUTER_API_KEY` environment variable.
|
||||||
|
/// Panics if the environment variable is not set.
|
||||||
|
pub fn from_env() -> Self {
|
||||||
|
let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
|
||||||
|
Self::new(&api_key)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
|
||||||
|
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
|
||||||
|
self.http_client.post(url)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a completion model with the given name.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
/// ```
|
||||||
|
/// use rig::providers::openrouter::{Client, self};
|
||||||
|
///
|
||||||
|
/// // Initialize the openrouter client
|
||||||
|
/// let openrouter = Client::new("your-openrouter-api-key");
|
||||||
|
///
|
||||||
|
/// let llama_3_1_8b = openrouter.completion_model(openrouter::LLAMA_3_1_8B);
|
||||||
|
/// ```
|
||||||
|
pub fn completion_model(&self, model: &str) -> CompletionModel {
|
||||||
|
CompletionModel::new(self.clone(), model)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create an agent builder with the given completion model.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
/// ```
|
||||||
|
/// use rig::providers::openrouter::{Client, self};
|
||||||
|
///
|
||||||
|
/// // Initialize the openrouter client
|
||||||
|
/// let openrouter = Client::new("your-openrouter-api-key");
|
||||||
|
///
|
||||||
|
/// let agent = openrouter.agent(openrouter::LLAMA_3_1_8B)
|
||||||
|
/// .preamble("You are comedian AI with a mission to make people laugh.")
|
||||||
|
/// .temperature(0.0)
|
||||||
|
/// .build();
|
||||||
|
/// ```
|
||||||
|
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
|
||||||
|
AgentBuilder::new(self.completion_model(model))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create an extractor builder with the given completion model.
|
||||||
|
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
|
||||||
|
&self,
|
||||||
|
model: &str,
|
||||||
|
) -> ExtractorBuilder<T, CompletionModel> {
|
||||||
|
ExtractorBuilder::new(self.completion_model(model))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct ApiErrorResponse {
|
||||||
|
pub(crate) message: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum ApiResponse<T> {
|
||||||
|
Ok(T),
|
||||||
|
Err(ApiErrorResponse),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize)]
|
||||||
|
pub struct Usage {
|
||||||
|
pub prompt_tokens: usize,
|
||||||
|
pub completion_tokens: usize,
|
||||||
|
pub total_tokens: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for Usage {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"Prompt tokens: {} Total tokens: {}",
|
||||||
|
self.prompt_tokens, self.total_tokens
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,147 +1,16 @@
|
||||||
//! OpenRouter Inference API client and Rig integration
|
use serde::Deserialize;
|
||||||
//!
|
|
||||||
//! # Example
|
use super::client::{ApiErrorResponse, ApiResponse, Client, Usage};
|
||||||
//! ```
|
|
||||||
//! use rig::providers::openrouter;
|
|
||||||
//!
|
|
||||||
//! let client = openrouter::Client::new("YOUR_API_KEY");
|
|
||||||
//!
|
|
||||||
//! let llama_3_1_8b = client.completion_model(openrouter::LLAMA_3_1_8B);
|
|
||||||
//! ```
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
agent::AgentBuilder,
|
|
||||||
completion::{self, CompletionError, CompletionRequest},
|
completion::{self, CompletionError, CompletionRequest},
|
||||||
extractor::ExtractorBuilder,
|
|
||||||
json_utils,
|
json_utils,
|
||||||
providers::openai::Message,
|
providers::openai::Message,
|
||||||
OneOrMany,
|
OneOrMany,
|
||||||
};
|
};
|
||||||
use schemars::JsonSchema;
|
use serde_json::{json, Value};
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use serde_json::json;
|
|
||||||
|
|
||||||
use super::openai::AssistantContent;
|
use crate::providers::openai::AssistantContent;
|
||||||
|
|
||||||
// ================================================================
|
|
||||||
// Main openrouter Client
|
|
||||||
// ================================================================
|
|
||||||
const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct Client {
|
|
||||||
base_url: String,
|
|
||||||
http_client: reqwest::Client,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Client {
|
|
||||||
/// Create a new OpenRouter client with the given API key.
|
|
||||||
pub fn new(api_key: &str) -> Self {
|
|
||||||
Self::from_url(api_key, OPENROUTER_API_BASE_URL)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a new OpenRouter client with the given API key and base API URL.
|
|
||||||
pub fn from_url(api_key: &str, base_url: &str) -> Self {
|
|
||||||
Self {
|
|
||||||
base_url: base_url.to_string(),
|
|
||||||
http_client: reqwest::Client::builder()
|
|
||||||
.default_headers({
|
|
||||||
let mut headers = reqwest::header::HeaderMap::new();
|
|
||||||
headers.insert(
|
|
||||||
"Authorization",
|
|
||||||
format!("Bearer {}", api_key)
|
|
||||||
.parse()
|
|
||||||
.expect("Bearer token should parse"),
|
|
||||||
);
|
|
||||||
headers
|
|
||||||
})
|
|
||||||
.build()
|
|
||||||
.expect("OpenRouter reqwest client should build"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a new openrouter client from the `openrouter_API_KEY` environment variable.
|
|
||||||
/// Panics if the environment variable is not set.
|
|
||||||
pub fn from_env() -> Self {
|
|
||||||
let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
|
|
||||||
Self::new(&api_key)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn post(&self, path: &str) -> reqwest::RequestBuilder {
|
|
||||||
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
|
|
||||||
self.http_client.post(url)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a completion model with the given name.
|
|
||||||
///
|
|
||||||
/// # Example
|
|
||||||
/// ```
|
|
||||||
/// use rig::providers::openrouter::{Client, self};
|
|
||||||
///
|
|
||||||
/// // Initialize the openrouter client
|
|
||||||
/// let openrouter = Client::new("your-openrouter-api-key");
|
|
||||||
///
|
|
||||||
/// let llama_3_1_8b = openrouter.completion_model(openrouter::LLAMA_3_1_8B);
|
|
||||||
/// ```
|
|
||||||
pub fn completion_model(&self, model: &str) -> CompletionModel {
|
|
||||||
CompletionModel::new(self.clone(), model)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create an agent builder with the given completion model.
|
|
||||||
///
|
|
||||||
/// # Example
|
|
||||||
/// ```
|
|
||||||
/// use rig::providers::openrouter::{Client, self};
|
|
||||||
///
|
|
||||||
/// // Initialize the Eternal client
|
|
||||||
/// let openrouter = Client::new("your-openrouter-api-key");
|
|
||||||
///
|
|
||||||
/// let agent = openrouter.agent(openrouter::LLAMA_3_1_8B)
|
|
||||||
/// .preamble("You are comedian AI with a mission to make people laugh.")
|
|
||||||
/// .temperature(0.0)
|
|
||||||
/// .build();
|
|
||||||
/// ```
|
|
||||||
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
|
|
||||||
AgentBuilder::new(self.completion_model(model))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create an extractor builder with the given completion model.
|
|
||||||
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
|
|
||||||
&self,
|
|
||||||
model: &str,
|
|
||||||
) -> ExtractorBuilder<T, CompletionModel> {
|
|
||||||
ExtractorBuilder::new(self.completion_model(model))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct ApiErrorResponse {
|
|
||||||
message: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
enum ApiResponse<T> {
|
|
||||||
Ok(T),
|
|
||||||
Err(ApiErrorResponse),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize)]
|
|
||||||
pub struct Usage {
|
|
||||||
pub prompt_tokens: usize,
|
|
||||||
pub completion_tokens: usize,
|
|
||||||
pub total_tokens: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Display for Usage {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
write!(
|
|
||||||
f,
|
|
||||||
"Prompt tokens: {} Total tokens: {}",
|
|
||||||
self.prompt_tokens, self.total_tokens
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ================================================================
|
// ================================================================
|
||||||
// OpenRouter Completion API
|
// OpenRouter Completion API
|
||||||
|
@ -241,7 +110,7 @@ pub struct Choice {
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct CompletionModel {
|
pub struct CompletionModel {
|
||||||
client: Client,
|
pub(crate) client: Client,
|
||||||
/// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
|
/// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
|
||||||
pub model: String,
|
pub model: String,
|
||||||
}
|
}
|
||||||
|
@ -253,16 +122,11 @@ impl CompletionModel {
|
||||||
model: model.to_string(),
|
model: model.to_string(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
impl completion::CompletionModel for CompletionModel {
|
pub(crate) fn create_completion_request(
|
||||||
type Response = CompletionResponse;
|
|
||||||
|
|
||||||
#[cfg_attr(feature = "worker", worker::send)]
|
|
||||||
async fn completion(
|
|
||||||
&self,
|
&self,
|
||||||
completion_request: CompletionRequest,
|
completion_request: CompletionRequest,
|
||||||
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
|
) -> Result<Value, CompletionError> {
|
||||||
// Add preamble to chat history (if available)
|
// Add preamble to chat history (if available)
|
||||||
let mut full_history: Vec<Message> = match &completion_request.preamble {
|
let mut full_history: Vec<Message> = match &completion_request.preamble {
|
||||||
Some(preamble) => vec![Message::system(preamble)],
|
Some(preamble) => vec![Message::system(preamble)],
|
||||||
|
@ -292,16 +156,30 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
"temperature": completion_request.temperature,
|
"temperature": completion_request.temperature,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let request = if let Some(params) = completion_request.additional_params {
|
||||||
|
json_utils::merge(request, params)
|
||||||
|
} else {
|
||||||
|
request
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(request)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl completion::CompletionModel for CompletionModel {
|
||||||
|
type Response = CompletionResponse;
|
||||||
|
|
||||||
|
#[cfg_attr(feature = "worker", worker::send)]
|
||||||
|
async fn completion(
|
||||||
|
&self,
|
||||||
|
completion_request: CompletionRequest,
|
||||||
|
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
|
||||||
|
let request = self.create_completion_request(completion_request)?;
|
||||||
|
|
||||||
let response = self
|
let response = self
|
||||||
.client
|
.client
|
||||||
.post("/chat/completions")
|
.post("/chat/completions")
|
||||||
.json(
|
.json(&request)
|
||||||
&if let Some(params) = completion_request.additional_params {
|
|
||||||
json_utils::merge(request, params)
|
|
||||||
} else {
|
|
||||||
request
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
//! OpenRouter Inference API client and Rig integration
|
||||||
|
//!
|
||||||
|
//! # Example
|
||||||
|
//! ```
|
||||||
|
//! use rig::providers::openrouter;
|
||||||
|
//!
|
||||||
|
//! let client = openrouter::Client::new("YOUR_API_KEY");
|
||||||
|
//!
|
||||||
|
//! let llama_3_1_8b = client.completion_model(openrouter::LLAMA_3_1_8B);
|
||||||
|
//! ```
|
||||||
|
|
||||||
|
pub mod client;
|
||||||
|
pub mod completion;
|
||||||
|
pub mod streaming;
|
||||||
|
|
||||||
|
pub use client::*;
|
||||||
|
pub use completion::*;
|
|
@ -0,0 +1,313 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
json_utils,
|
||||||
|
message::{ToolCall, ToolFunction},
|
||||||
|
streaming::{self},
|
||||||
|
};
|
||||||
|
use async_stream::stream;
|
||||||
|
use futures::StreamExt;
|
||||||
|
use reqwest::RequestBuilder;
|
||||||
|
use serde_json::{json, Value};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
completion::{CompletionError, CompletionRequest},
|
||||||
|
streaming::StreamingCompletionModel,
|
||||||
|
};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
pub struct StreamingCompletionResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub choices: Vec<StreamingChoice>,
|
||||||
|
pub created: u64,
|
||||||
|
pub model: String,
|
||||||
|
pub object: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub system_fingerprint: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub usage: Option<ResponseUsage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
pub struct StreamingChoice {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub finish_reason: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub native_finish_reason: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub logprobs: Option<Value>,
|
||||||
|
pub index: usize,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub message: Option<MessageResponse>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub delta: Option<DeltaResponse>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub error: Option<ErrorResponse>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
pub struct MessageResponse {
|
||||||
|
pub role: String,
|
||||||
|
pub content: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub refusal: Option<Value>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub tool_calls: Vec<OpenRouterToolCall>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
pub struct OpenRouterToolFunction {
|
||||||
|
pub name: Option<String>,
|
||||||
|
pub arguments: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
pub struct OpenRouterToolCall {
|
||||||
|
pub index: usize,
|
||||||
|
pub id: Option<String>,
|
||||||
|
pub r#type: Option<String>,
|
||||||
|
pub function: OpenRouterToolFunction,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
|
||||||
|
pub struct ResponseUsage {
|
||||||
|
pub prompt_tokens: u32,
|
||||||
|
pub completion_tokens: u32,
|
||||||
|
pub total_tokens: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
pub struct ErrorResponse {
|
||||||
|
pub code: i32,
|
||||||
|
pub message: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub metadata: Option<HashMap<String, Value>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
pub struct DeltaResponse {
|
||||||
|
pub role: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub content: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub tool_calls: Vec<OpenRouterToolCall>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub native_finish_reason: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct FinalCompletionResponse {
|
||||||
|
pub usage: ResponseUsage,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StreamingCompletionModel for super::CompletionModel {
|
||||||
|
type StreamingResponse = FinalCompletionResponse;
|
||||||
|
|
||||||
|
async fn stream(
|
||||||
|
&self,
|
||||||
|
completion_request: CompletionRequest,
|
||||||
|
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
|
||||||
|
{
|
||||||
|
let request = self.create_completion_request(completion_request)?;
|
||||||
|
|
||||||
|
let request = json_utils::merge(request, json!({"stream": true}));
|
||||||
|
|
||||||
|
let builder = self.client.post("/chat/completions").json(&request);
|
||||||
|
|
||||||
|
send_streaming_request(builder).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn send_streaming_request(
|
||||||
|
request_builder: RequestBuilder,
|
||||||
|
) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError> {
|
||||||
|
let response = request_builder.send().await?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
return Err(CompletionError::ProviderError(format!(
|
||||||
|
"{}: {}",
|
||||||
|
response.status(),
|
||||||
|
response.text().await?
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle OpenAI Compatible SSE chunks
|
||||||
|
let stream = Box::pin(stream! {
|
||||||
|
let mut stream = response.bytes_stream();
|
||||||
|
let mut tool_calls = HashMap::new();
|
||||||
|
let mut partial_line = String::new();
|
||||||
|
let mut final_usage = None;
|
||||||
|
|
||||||
|
while let Some(chunk_result) = stream.next().await {
|
||||||
|
let chunk = match chunk_result {
|
||||||
|
Ok(c) => c,
|
||||||
|
Err(e) => {
|
||||||
|
yield Err(CompletionError::from(e));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let text = match String::from_utf8(chunk.to_vec()) {
|
||||||
|
Ok(t) => t,
|
||||||
|
Err(e) => {
|
||||||
|
yield Err(CompletionError::ResponseError(e.to_string()));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
for line in text.lines() {
|
||||||
|
let mut line = line.to_string();
|
||||||
|
|
||||||
|
// Skip empty lines and processing messages, as well as [DONE] (might be useful though)
|
||||||
|
if line.trim().is_empty() || line.trim() == ": OPENROUTER PROCESSING" || line.trim() == "data: [DONE]" {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle data: prefix
|
||||||
|
line = line.strip_prefix("data: ").unwrap_or(&line).to_string();
|
||||||
|
|
||||||
|
// If line starts with { but doesn't end with }, it's a partial JSON
|
||||||
|
if line.starts_with('{') && !line.ends_with('}') {
|
||||||
|
partial_line = line;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we have a partial line and this line ends with }, complete it
|
||||||
|
if !partial_line.is_empty() {
|
||||||
|
if line.ends_with('}') {
|
||||||
|
partial_line.push_str(&line);
|
||||||
|
line = partial_line;
|
||||||
|
partial_line = String::new();
|
||||||
|
} else {
|
||||||
|
partial_line.push_str(&line);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let data = match serde_json::from_str::<StreamingCompletionResponse>(&line) {
|
||||||
|
Ok(data) => data,
|
||||||
|
Err(_) => {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
let choice = data.choices.first().expect("Should have at least one choice");
|
||||||
|
|
||||||
|
// TODO this has to handle outputs like this:
|
||||||
|
// [{"index": 0, "id": "call_DdmO9pD3xa9XTPNJ32zg2hcA", "function": {"arguments": "", "name": "get_weather"}, "type": "function"}]
|
||||||
|
// [{"index": 0, "id": null, "function": {"arguments": "{\"", "name": null}, "type": null}]
|
||||||
|
// [{"index": 0, "id": null, "function": {"arguments": "location", "name": null}, "type": null}]
|
||||||
|
// [{"index": 0, "id": null, "function": {"arguments": "\":\"", "name": null}, "type": null}]
|
||||||
|
// [{"index": 0, "id": null, "function": {"arguments": "Paris", "name": null}, "type": null}]
|
||||||
|
// [{"index": 0, "id": null, "function": {"arguments": ",", "name": null}, "type": null}]
|
||||||
|
// [{"index": 0, "id": null, "function": {"arguments": " France", "name": null}, "type": null}]
|
||||||
|
// [{"index": 0, "id": null, "function": {"arguments": "\"}", "name": null}, "type": null}]
|
||||||
|
if let Some(delta) = &choice.delta {
|
||||||
|
if !delta.tool_calls.is_empty() {
|
||||||
|
for tool_call in &delta.tool_calls {
|
||||||
|
let index = tool_call.index;
|
||||||
|
|
||||||
|
// Get or create tool call entry
|
||||||
|
let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
|
||||||
|
id: String::new(),
|
||||||
|
function: ToolFunction {
|
||||||
|
name: String::new(),
|
||||||
|
arguments: serde_json::Value::Null,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Update fields if present
|
||||||
|
if let Some(id) = &tool_call.id {
|
||||||
|
if !id.is_empty() {
|
||||||
|
existing_tool_call.id = id.clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(name) = &tool_call.function.name {
|
||||||
|
if !name.is_empty() {
|
||||||
|
existing_tool_call.function.name = name.clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(chunk) = &tool_call.function.arguments {
|
||||||
|
// Convert current arguments to string if needed
|
||||||
|
let current_args = match &existing_tool_call.function.arguments {
|
||||||
|
serde_json::Value::Null => String::new(),
|
||||||
|
serde_json::Value::String(s) => s.clone(),
|
||||||
|
v => v.to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Concatenate the new chunk
|
||||||
|
let combined = format!("{}{}", current_args, chunk);
|
||||||
|
|
||||||
|
// Try to parse as JSON if it looks complete
|
||||||
|
if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
|
||||||
|
match serde_json::from_str(&combined) {
|
||||||
|
Ok(parsed) => existing_tool_call.function.arguments = parsed,
|
||||||
|
Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
existing_tool_call.function.arguments = serde_json::Value::String(combined);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(content) = &delta.content {
|
||||||
|
if !content.is_empty() {
|
||||||
|
yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(usage) = data.usage {
|
||||||
|
final_usage = Some(usage);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle message format
|
||||||
|
if let Some(message) = &choice.message {
|
||||||
|
if !message.tool_calls.is_empty() {
|
||||||
|
for tool_call in &message.tool_calls {
|
||||||
|
let name = tool_call.function.name.clone();
|
||||||
|
let id = tool_call.id.clone();
|
||||||
|
let arguments = if let Some(args) = &tool_call.function.arguments {
|
||||||
|
// Try to parse the string as JSON, fallback to string value
|
||||||
|
match serde_json::from_str(args) {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(_) => serde_json::Value::String(args.to_string()),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
serde_json::Value::Null
|
||||||
|
};
|
||||||
|
let index = tool_call.index;
|
||||||
|
|
||||||
|
tool_calls.insert(index, ToolCall{
|
||||||
|
id: id.unwrap_or_default(),
|
||||||
|
function: ToolFunction {
|
||||||
|
name: name.unwrap_or_default(),
|
||||||
|
arguments,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !message.content.is_empty() {
|
||||||
|
yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (_, tool_call) in tool_calls.into_iter() {
|
||||||
|
|
||||||
|
yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call.function.name, tool_call.id, tool_call.function.arguments));
|
||||||
|
}
|
||||||
|
|
||||||
|
yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
|
||||||
|
usage: final_usage.unwrap_or_default()
|
||||||
|
}))
|
||||||
|
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(streaming::StreamingCompletionResponse::new(stream))
|
||||||
|
}
|
|
@ -18,8 +18,9 @@ use crate::{
|
||||||
|
|
||||||
use crate::completion::CompletionRequest;
|
use crate::completion::CompletionRequest;
|
||||||
use crate::json_utils::merge;
|
use crate::json_utils::merge;
|
||||||
|
use crate::providers::openai;
|
||||||
use crate::providers::openai::send_compatible_streaming_request;
|
use crate::providers::openai::send_compatible_streaming_request;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Value};
|
||||||
|
@ -345,10 +346,11 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
|
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
completion_request: completion::CompletionRequest,
|
completion_request: completion::CompletionRequest,
|
||||||
) -> Result<StreamingResult, CompletionError> {
|
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(completion_request)?;
|
let mut request = self.create_completion_request(completion_request)?;
|
||||||
|
|
||||||
request = merge(request, json!({"stream": true}));
|
request = merge(request, json!({"stream": true}));
|
||||||
|
|
|
@ -1,18 +1,21 @@
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
|
||||||
use super::completion::CompletionModel;
|
use super::completion::CompletionModel;
|
||||||
|
use crate::providers::openai;
|
||||||
use crate::providers::openai::send_compatible_streaming_request;
|
use crate::providers::openai::send_compatible_streaming_request;
|
||||||
|
use crate::streaming::StreamingCompletionResponse;
|
||||||
use crate::{
|
use crate::{
|
||||||
completion::{CompletionError, CompletionRequest},
|
completion::{CompletionError, CompletionRequest},
|
||||||
json_utils::merge,
|
json_utils::merge,
|
||||||
streaming::{StreamingCompletionModel, StreamingResult},
|
streaming::StreamingCompletionModel,
|
||||||
};
|
};
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
|
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
completion_request: CompletionRequest,
|
completion_request: CompletionRequest,
|
||||||
) -> Result<StreamingResult, CompletionError> {
|
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(completion_request)?;
|
let mut request = self.create_completion_request(completion_request)?;
|
||||||
|
|
||||||
request = merge(request, json!({"stream_tokens": true}));
|
request = merge(request, json!({"stream_tokens": true}));
|
||||||
|
|
|
@ -1,15 +1,17 @@
|
||||||
use crate::completion::{CompletionError, CompletionRequest};
|
use crate::completion::{CompletionError, CompletionRequest};
|
||||||
use crate::json_utils::merge;
|
use crate::json_utils::merge;
|
||||||
|
use crate::providers::openai;
|
||||||
use crate::providers::openai::send_compatible_streaming_request;
|
use crate::providers::openai::send_compatible_streaming_request;
|
||||||
use crate::providers::xai::completion::CompletionModel;
|
use crate::providers::xai::completion::CompletionModel;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
|
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
completion_request: CompletionRequest,
|
completion_request: CompletionRequest,
|
||||||
) -> Result<StreamingResult, CompletionError> {
|
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(completion_request)?;
|
let mut request = self.create_completion_request(completion_request)?;
|
||||||
|
|
||||||
request = merge(request, json!({"stream": true}));
|
request = merge(request, json!({"stream": true}));
|
||||||
|
|
|
@ -11,59 +11,150 @@
|
||||||
|
|
||||||
use crate::agent::Agent;
|
use crate::agent::Agent;
|
||||||
use crate::completion::{
|
use crate::completion::{
|
||||||
CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, Message,
|
CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder,
|
||||||
|
CompletionResponse, Message,
|
||||||
};
|
};
|
||||||
|
use crate::message::{AssistantContent, ToolCall, ToolFunction};
|
||||||
|
use crate::OneOrMany;
|
||||||
use futures::{Stream, StreamExt};
|
use futures::{Stream, StreamExt};
|
||||||
use std::boxed::Box;
|
use std::boxed::Box;
|
||||||
use std::fmt::{Display, Formatter};
|
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
|
||||||
/// Enum representing a streaming chunk from the model
|
/// Enum representing a streaming chunk from the model
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum StreamingChoice {
|
pub enum RawStreamingChoice<R: Clone> {
|
||||||
/// A text chunk from a message response
|
/// A text chunk from a message response
|
||||||
Message(String),
|
Message(String),
|
||||||
|
|
||||||
/// A tool call response chunk
|
/// A tool call response chunk
|
||||||
ToolCall(String, String, serde_json::Value),
|
ToolCall(String, String, serde_json::Value),
|
||||||
|
|
||||||
|
/// The final response object, must be yielded if you want the
|
||||||
|
/// `response` field to be populated on the `StreamingCompletionResponse`
|
||||||
|
FinalResponse(R),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Display for StreamingChoice {
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
pub type StreamingResult<R> =
|
||||||
match self {
|
Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;
|
||||||
StreamingChoice::Message(text) => write!(f, "{}", text),
|
|
||||||
StreamingChoice::ToolCall(name, id, params) => {
|
#[cfg(target_arch = "wasm32")]
|
||||||
write!(f, "Tool call: {} {} {:?}", name, id, params)
|
pub type StreamingResult<R> =
|
||||||
}
|
Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
|
||||||
|
|
||||||
|
/// The response from a streaming completion request.
|
||||||
|
/// It wraps the provider's `inner` stream; `choice` is rebuilt when the
|
||||||
|
/// `inner` stream ends, and `response` is set when the provider yields
/// a `RawStreamingChoice::FinalResponse`.
|
||||||
|
pub struct StreamingCompletionResponse<R: Clone + Unpin> {
|
||||||
|
// The raw provider stream that is polled for chunks.
inner: StreamingResult<R>,
|
||||||
|
// Running concatenation of every text chunk seen so far.
text: String,
|
||||||
|
// Every tool call seen so far, in arrival order.
tool_calls: Vec<ToolCall>,
|
||||||
|
/// The final aggregated message from the stream;
|
||||||
|
/// contains all text and tool calls generated.
/// Holds a single empty text item until the stream has ended.
|
||||||
|
pub choice: OneOrMany<AssistantContent>,
|
||||||
|
/// The final response from the stream, may be `None`
|
||||||
|
/// if the provider didn't yield it during the stream
|
||||||
|
pub response: Option<R>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: Clone + Unpin> StreamingCompletionResponse<R> {
|
||||||
|
pub fn new(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
|
||||||
|
Self {
|
||||||
|
inner,
|
||||||
|
text: "".to_string(),
|
||||||
|
tool_calls: vec![],
|
||||||
|
choice: OneOrMany::one(AssistantContent::text("")),
|
||||||
|
response: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(target_arch = "wasm32"))]
|
impl<R: Clone + Unpin> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>> {
|
||||||
pub type StreamingResult =
|
fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
|
||||||
Pin<Box<dyn Stream<Item = Result<StreamingChoice, CompletionError>> + Send>>;
|
CompletionResponse {
|
||||||
|
choice: value.choice,
|
||||||
|
raw_response: value.response,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(target_arch = "wasm32")]
|
impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
|
||||||
pub type StreamingResult = Pin<Box<dyn Stream<Item = Result<StreamingChoice, CompletionError>>>>;
|
type Item = Result<AssistantContent, CompletionError>;
|
||||||
|
|
||||||
|
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
|
let stream = self.get_mut();
|
||||||
|
|
||||||
|
match stream.inner.as_mut().poll_next(cx) {
|
||||||
|
Poll::Pending => Poll::Pending,
|
||||||
|
Poll::Ready(None) => {
|
||||||
|
// This is run at the end of the inner stream to collect all tokens into
|
||||||
|
// a single unified `Message`.
|
||||||
|
let mut choice = vec![];
|
||||||
|
|
||||||
|
stream.tool_calls.iter().for_each(|tc| {
|
||||||
|
choice.push(AssistantContent::ToolCall(tc.clone()));
|
||||||
|
});
|
||||||
|
|
||||||
|
// This is required to ensure there's always at least one item in the content
|
||||||
|
if choice.is_empty() || !stream.text.is_empty() {
|
||||||
|
choice.insert(0, AssistantContent::text(stream.text.clone()));
|
||||||
|
}
|
||||||
|
|
||||||
|
stream.choice = OneOrMany::many(choice)
|
||||||
|
.expect("There should be at least one assistant message");
|
||||||
|
|
||||||
|
Poll::Ready(None)
|
||||||
|
}
|
||||||
|
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
|
||||||
|
Poll::Ready(Some(Ok(choice))) => match choice {
|
||||||
|
RawStreamingChoice::Message(text) => {
|
||||||
|
// Forward the streaming tokens to the outer stream
|
||||||
|
// and concat the text together
|
||||||
|
stream.text = format!("{}{}", stream.text, text.clone());
|
||||||
|
Poll::Ready(Some(Ok(AssistantContent::text(text))))
|
||||||
|
}
|
||||||
|
RawStreamingChoice::ToolCall(id, name, args) => {
|
||||||
|
// Keep track of each tool call to aggregate the final message later
|
||||||
|
// and pass it to the outer stream
|
||||||
|
stream.tool_calls.push(ToolCall {
|
||||||
|
id: id.clone(),
|
||||||
|
function: ToolFunction {
|
||||||
|
name: name.clone(),
|
||||||
|
arguments: args.clone(),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
Poll::Ready(Some(Ok(AssistantContent::tool_call(id, name, args))))
|
||||||
|
}
|
||||||
|
RawStreamingChoice::FinalResponse(response) => {
|
||||||
|
// Set the final response field and return the next item in the stream
|
||||||
|
stream.response = Some(response);
|
||||||
|
|
||||||
|
stream.poll_next_unpin(cx)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Trait for high-level streaming prompt interface
|
/// Trait for high-level streaming prompt interface
|
||||||
pub trait StreamingPrompt: Send + Sync {
|
pub trait StreamingPrompt<R: Clone + Unpin>: Send + Sync {
|
||||||
/// Stream a simple prompt to the model
|
/// Stream a simple prompt to the model
|
||||||
fn stream_prompt(
|
fn stream_prompt(
|
||||||
&self,
|
&self,
|
||||||
prompt: &str,
|
prompt: &str,
|
||||||
) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
|
) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trait for high-level streaming chat interface
|
/// Trait for high-level streaming chat interface
|
||||||
pub trait StreamingChat: Send + Sync {
|
pub trait StreamingChat<R: Clone + Unpin>: Send + Sync {
|
||||||
/// Stream a chat with history to the model
|
/// Stream a chat with history to the model
|
||||||
fn stream_chat(
|
fn stream_chat(
|
||||||
&self,
|
&self,
|
||||||
prompt: &str,
|
prompt: &str,
|
||||||
chat_history: Vec<Message>,
|
chat_history: Vec<Message>,
|
||||||
) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
|
) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trait for low-level streaming completion interface
|
/// Trait for low-level streaming completion interface
|
||||||
|
@ -78,29 +169,35 @@ pub trait StreamingCompletion<M: StreamingCompletionModel> {
|
||||||
|
|
||||||
/// Trait defining a streaming completion model
|
/// Trait defining a streaming completion model
|
||||||
pub trait StreamingCompletionModel: CompletionModel {
|
pub trait StreamingCompletionModel: CompletionModel {
|
||||||
|
type StreamingResponse: Clone + Unpin;
|
||||||
/// Stream a completion response for the given request
|
/// Stream a completion response for the given request
|
||||||
fn stream(
|
fn stream(
|
||||||
&self,
|
&self,
|
||||||
request: CompletionRequest,
|
request: CompletionRequest,
|
||||||
) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
|
) -> impl Future<
|
||||||
|
Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
|
||||||
|
>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// helper function to stream a completion request to stdout
|
/// helper function to stream a completion request to stdout
|
||||||
pub async fn stream_to_stdout<M: StreamingCompletionModel>(
|
pub async fn stream_to_stdout<M: StreamingCompletionModel>(
|
||||||
agent: Agent<M>,
|
agent: Agent<M>,
|
||||||
stream: &mut StreamingResult,
|
stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
|
||||||
) -> Result<(), std::io::Error> {
|
) -> Result<(), std::io::Error> {
|
||||||
print!("Response: ");
|
print!("Response: ");
|
||||||
while let Some(chunk) = stream.next().await {
|
while let Some(chunk) = stream.next().await {
|
||||||
match chunk {
|
match chunk {
|
||||||
Ok(StreamingChoice::Message(text)) => {
|
Ok(AssistantContent::Text(text)) => {
|
||||||
print!("{}", text);
|
print!("{}", text.text);
|
||||||
std::io::Write::flush(&mut std::io::stdout())?;
|
std::io::Write::flush(&mut std::io::stdout())?;
|
||||||
}
|
}
|
||||||
Ok(StreamingChoice::ToolCall(name, _, params)) => {
|
Ok(AssistantContent::ToolCall(tool_call)) => {
|
||||||
let res = agent
|
let res = agent
|
||||||
.tools
|
.tools
|
||||||
.call(&name, params.to_string())
|
.call(
|
||||||
|
&tool_call.function.name,
|
||||||
|
tool_call.function.arguments.to_string(),
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| std::io::Error::other(e.to_string()))?;
|
.map_err(|e| std::io::Error::other(e.to_string()))?;
|
||||||
println!("\nResult: {}", res);
|
println!("\nResult: {}", res);
|
||||||
|
@ -111,6 +208,7 @@ pub async fn stream_to_stdout<M: StreamingCompletionModel>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
println!(); // New line after streaming completes
|
println!(); // New line after streaming completes
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
Loading…
Reference in New Issue