mirror of https://github.com/0xplaygrounds/rig
Merge a78969fd9a
into 33e8fc7a65
This commit is contained in:
commit
d3f6857019
|
@ -19,5 +19,11 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||
|
||||
stream_to_stdout(agent, &mut stream).await?;
|
||||
|
||||
if let Some(response) = stream.response {
|
||||
println!("Usage: {:?} tokens", response.usage.output_tokens);
|
||||
};
|
||||
|
||||
println!("Message: {:?}", stream.choice);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||
println!("Calculate 2 - 5");
|
||||
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
|
||||
stream_to_stdout(calculator_agent, &mut stream).await?;
|
||||
|
||||
if let Some(response) = stream.response {
|
||||
println!("Usage: {:?} tokens", response.usage.output_tokens);
|
||||
};
|
||||
|
||||
println!("Message: {:?}", stream.choice);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
use rig::providers::cohere;
|
||||
use rig::streaming::{stream_to_stdout, StreamingPrompt};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
// Create streaming agent with a single context prompt
|
||||
let agent = cohere::Client::from_env()
|
||||
.agent(cohere::COMMAND)
|
||||
.preamble("Be precise and concise.")
|
||||
.temperature(0.5)
|
||||
.build();
|
||||
|
||||
// Stream the response and print chunks as they arrive
|
||||
let mut stream = agent
|
||||
.stream_prompt("When and where and what type is the next solar eclipse?")
|
||||
.await?;
|
||||
|
||||
stream_to_stdout(agent, &mut stream).await?;
|
||||
|
||||
if let Some(response) = stream.response {
|
||||
println!("Usage: {:?} tokens", response.usage);
|
||||
};
|
||||
|
||||
println!("Message: {:?}", stream.choice);
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -0,0 +1,118 @@
|
|||
use anyhow::Result;
|
||||
use rig::streaming::stream_to_stdout;
|
||||
use rig::{completion::ToolDefinition, providers, streaming::StreamingPrompt, tool::Tool};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OperationArgs {
|
||||
x: i32,
|
||||
y: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("Math error")]
|
||||
struct MathError;
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
struct Adder;
|
||||
impl Tool for Adder {
|
||||
const NAME: &'static str = "add";
|
||||
|
||||
type Error = MathError;
|
||||
type Args = OperationArgs;
|
||||
type Output = i32;
|
||||
|
||||
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||
ToolDefinition {
|
||||
name: "add".to_string(),
|
||||
description: "Add x and y together".to_string(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {
|
||||
"type": "number",
|
||||
"description": "The first number to add"
|
||||
},
|
||||
"y": {
|
||||
"type": "number",
|
||||
"description": "The second number to add"
|
||||
}
|
||||
},
|
||||
"required": ["x", "y"]
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||
let result = args.x + args.y;
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
struct Subtract;
|
||||
impl Tool for Subtract {
|
||||
const NAME: &'static str = "subtract";
|
||||
|
||||
type Error = MathError;
|
||||
type Args = OperationArgs;
|
||||
type Output = i32;
|
||||
|
||||
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||
serde_json::from_value(json!({
|
||||
"name": "subtract",
|
||||
"description": "Subtract y from x (i.e.: x - y)",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {
|
||||
"type": "number",
|
||||
"description": "The number to subtract from"
|
||||
},
|
||||
"y": {
|
||||
"type": "number",
|
||||
"description": "The number to subtract"
|
||||
}
|
||||
},
|
||||
"required": ["x", "y"]
|
||||
}
|
||||
}))
|
||||
.expect("Tool Definition")
|
||||
}
|
||||
|
||||
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||
let result = args.x - args.y;
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
tracing_subscriber::fmt().init();
|
||||
// Create agent with a single context prompt and two tools
|
||||
let calculator_agent = providers::cohere::Client::from_env()
|
||||
.agent(providers::cohere::COMMAND_R)
|
||||
.preamble(
|
||||
"You are a calculator here to help the user perform arithmetic
|
||||
operations. Use the tools provided to answer the user's question.
|
||||
make your answer long, so we can test the streaming functionality,
|
||||
like 20 words",
|
||||
)
|
||||
.max_tokens(1024)
|
||||
.tool(Adder)
|
||||
.tool(Subtract)
|
||||
.build();
|
||||
|
||||
println!("Calculate 2 - 5");
|
||||
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
|
||||
stream_to_stdout(calculator_agent, &mut stream).await?;
|
||||
|
||||
if let Some(response) = stream.response {
|
||||
println!("Usage: {:?} tokens", response.usage);
|
||||
};
|
||||
|
||||
println!("Message: {:?}", stream.choice);
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -19,5 +19,13 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||
|
||||
stream_to_stdout(agent, &mut stream).await?;
|
||||
|
||||
if let Some(response) = stream.response {
|
||||
println!(
|
||||
"Usage: {:?} tokens",
|
||||
response.usage_metadata.total_token_count
|
||||
);
|
||||
};
|
||||
|
||||
println!("Message: {:?}", stream.choice);
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -107,5 +107,15 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||
println!("Calculate 2 - 5");
|
||||
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
|
||||
stream_to_stdout(calculator_agent, &mut stream).await?;
|
||||
|
||||
if let Some(response) = stream.response {
|
||||
println!(
|
||||
"Usage: {:?} tokens",
|
||||
response.usage_metadata.total_token_count
|
||||
);
|
||||
};
|
||||
|
||||
println!("Message: {:?}", stream.choice);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -17,5 +17,10 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||
|
||||
stream_to_stdout(agent, &mut stream).await?;
|
||||
|
||||
if let Some(response) = stream.response {
|
||||
println!("Usage: {:?} tokens", response.eval_count);
|
||||
};
|
||||
|
||||
println!("Message: {:?}", stream.choice);
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||
println!("Calculate 2 - 5");
|
||||
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
|
||||
stream_to_stdout(calculator_agent, &mut stream).await?;
|
||||
|
||||
if let Some(response) = stream.response {
|
||||
println!("Usage: {:?} tokens", response.eval_count);
|
||||
};
|
||||
|
||||
println!("Message: {:?}", stream.choice);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -17,5 +17,11 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||
|
||||
stream_to_stdout(agent, &mut stream).await?;
|
||||
|
||||
if let Some(response) = stream.response {
|
||||
println!("Usage: {:?}", response.usage)
|
||||
};
|
||||
|
||||
println!("Message: {:?}", stream.choice);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||
println!("Calculate 2 - 5");
|
||||
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
|
||||
stream_to_stdout(calculator_agent, &mut stream).await?;
|
||||
|
||||
if let Some(response) = stream.response {
|
||||
println!("Usage: {:?}", response.usage)
|
||||
};
|
||||
|
||||
println!("Message: {:?}", stream.choice);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -0,0 +1,118 @@
|
|||
use anyhow::Result;
|
||||
use rig::streaming::stream_to_stdout;
|
||||
use rig::{completion::ToolDefinition, providers, streaming::StreamingPrompt, tool::Tool};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OperationArgs {
|
||||
x: i32,
|
||||
y: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("Math error")]
|
||||
struct MathError;
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
struct Adder;
|
||||
impl Tool for Adder {
|
||||
const NAME: &'static str = "add";
|
||||
|
||||
type Error = MathError;
|
||||
type Args = OperationArgs;
|
||||
type Output = i32;
|
||||
|
||||
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||
ToolDefinition {
|
||||
name: "add".to_string(),
|
||||
description: "Add x and y together".to_string(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {
|
||||
"type": "number",
|
||||
"description": "The first number to add"
|
||||
},
|
||||
"y": {
|
||||
"type": "number",
|
||||
"description": "The second number to add"
|
||||
}
|
||||
},
|
||||
"required": ["x", "y"]
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||
let result = args.x + args.y;
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
struct Subtract;
|
||||
impl Tool for Subtract {
|
||||
const NAME: &'static str = "subtract";
|
||||
|
||||
type Error = MathError;
|
||||
type Args = OperationArgs;
|
||||
type Output = i32;
|
||||
|
||||
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||
serde_json::from_value(json!({
|
||||
"name": "subtract",
|
||||
"description": "Subtract y from x (i.e.: x - y)",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {
|
||||
"type": "number",
|
||||
"description": "The number to subtract from"
|
||||
},
|
||||
"y": {
|
||||
"type": "number",
|
||||
"description": "The number to subtract"
|
||||
}
|
||||
},
|
||||
"required": ["x", "y"]
|
||||
}
|
||||
}))
|
||||
.expect("Tool Definition")
|
||||
}
|
||||
|
||||
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||
let result = args.x - args.y;
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
tracing_subscriber::fmt().init();
|
||||
// Create agent with a single context prompt and two tools
|
||||
let calculator_agent = providers::openrouter::Client::from_env()
|
||||
.agent(providers::openrouter::GEMINI_FLASH_2_0)
|
||||
.preamble(
|
||||
"You are a calculator here to help the user perform arithmetic
|
||||
operations. Use the tools provided to answer the user's question.
|
||||
make your answer long, so we can test the streaming functionality,
|
||||
like 20 words",
|
||||
)
|
||||
.max_tokens(1024)
|
||||
.tool(Adder)
|
||||
.tool(Subtract)
|
||||
.build();
|
||||
|
||||
println!("Calculate 2 - 5");
|
||||
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
|
||||
stream_to_stdout(calculator_agent, &mut stream).await?;
|
||||
|
||||
if let Some(response) = stream.response {
|
||||
println!("Usage: {:?}", response.usage)
|
||||
};
|
||||
|
||||
println!("Message: {:?}", stream.choice);
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -110,23 +110,20 @@ use std::collections::HashMap;
|
|||
|
||||
use futures::{stream, StreamExt, TryStreamExt};
|
||||
|
||||
use crate::streaming::StreamingCompletionResponse;
|
||||
#[cfg(feature = "mcp")]
|
||||
use crate::tool::McpTool;
|
||||
use crate::{
|
||||
completion::{
|
||||
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
|
||||
Message, Prompt, PromptError,
|
||||
},
|
||||
message::AssistantContent,
|
||||
streaming::{
|
||||
StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
|
||||
StreamingResult,
|
||||
},
|
||||
streaming::{StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt},
|
||||
tool::{Tool, ToolSet},
|
||||
vector_store::{VectorStoreError, VectorStoreIndexDyn},
|
||||
};
|
||||
|
||||
#[cfg(feature = "mcp")]
|
||||
use crate::tool::McpTool;
|
||||
|
||||
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
|
||||
/// (i.e.: system prompt) and a static set of context documents and tools.
|
||||
/// All context documents and tools are always provided to the agent when prompted.
|
||||
|
@ -500,18 +497,21 @@ impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
|
||||
async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
|
||||
impl<M: StreamingCompletionModel> StreamingPrompt<M::StreamingResponse> for Agent<M> {
|
||||
async fn stream_prompt(
|
||||
&self,
|
||||
prompt: &str,
|
||||
) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
|
||||
self.stream_chat(prompt, vec![]).await
|
||||
}
|
||||
}
|
||||
|
||||
impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
|
||||
impl<M: StreamingCompletionModel> StreamingChat<M::StreamingResponse> for Agent<M> {
|
||||
async fn stream_chat(
|
||||
&self,
|
||||
prompt: &str,
|
||||
chat_history: Vec<Message>,
|
||||
) -> Result<StreamingResult, CompletionError> {
|
||||
) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
|
||||
self.stream_completion(prompt, chat_history)
|
||||
.await?
|
||||
.stream()
|
||||
|
|
|
@ -67,7 +67,7 @@ use std::collections::HashMap;
|
|||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||
use crate::OneOrMany;
|
||||
use crate::{
|
||||
json_utils,
|
||||
|
@ -467,7 +467,9 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
|
|||
|
||||
impl<M: StreamingCompletionModel> CompletionRequestBuilder<M> {
|
||||
/// Stream the completion request
|
||||
pub async fn stream(self) -> Result<StreamingResult, CompletionError> {
|
||||
pub async fn stream(
|
||||
self,
|
||||
) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
|
||||
let model = self.model.clone();
|
||||
model.stream(self.build()).await
|
||||
}
|
||||
|
|
|
@ -8,7 +8,8 @@ use super::decoders::sse::from_response as sse_from_response;
|
|||
use crate::completion::{CompletionError, CompletionRequest};
|
||||
use crate::json_utils::merge_inplace;
|
||||
use crate::message::MessageError;
|
||||
use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
|
||||
use crate::streaming;
|
||||
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
|
@ -61,7 +62,7 @@ pub struct MessageDelta {
|
|||
pub stop_sequence: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct PartialUsage {
|
||||
pub output_tokens: usize,
|
||||
#[serde(default)]
|
||||
|
@ -75,11 +76,18 @@ struct ToolCallState {
|
|||
input_json: String,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct StreamingCompletionResponse {
|
||||
pub usage: PartialUsage,
|
||||
}
|
||||
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
type StreamingResponse = StreamingCompletionResponse;
|
||||
async fn stream(
|
||||
&self,
|
||||
completion_request: CompletionRequest,
|
||||
) -> Result<StreamingResult, CompletionError> {
|
||||
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
|
||||
{
|
||||
let max_tokens = if let Some(tokens) = completion_request.max_tokens {
|
||||
tokens
|
||||
} else if let Some(tokens) = self.default_max_tokens {
|
||||
|
@ -155,9 +163,10 @@ impl StreamingCompletionModel for CompletionModel {
|
|||
// Use our SSE decoder to directly handle Server-Sent Events format
|
||||
let sse_stream = sse_from_response(response);
|
||||
|
||||
Ok(Box::pin(stream! {
|
||||
let stream: StreamingResult<Self::StreamingResponse> = Box::pin(stream! {
|
||||
let mut current_tool_call: Option<ToolCallState> = None;
|
||||
let mut sse_stream = Box::pin(sse_stream);
|
||||
let mut input_tokens = 0;
|
||||
|
||||
while let Some(sse_result) = sse_stream.next().await {
|
||||
match sse_result {
|
||||
|
@ -165,6 +174,24 @@ impl StreamingCompletionModel for CompletionModel {
|
|||
// Parse the SSE data as a StreamingEvent
|
||||
match serde_json::from_str::<StreamingEvent>(&sse.data) {
|
||||
Ok(event) => {
|
||||
match &event {
|
||||
StreamingEvent::MessageStart { message } => {
|
||||
input_tokens = message.usage.input_tokens;
|
||||
},
|
||||
StreamingEvent::MessageDelta { delta, usage } => {
|
||||
if delta.stop_reason.is_some() {
|
||||
|
||||
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
|
||||
usage: PartialUsage {
|
||||
output_tokens: usage.output_tokens,
|
||||
input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
|
||||
}
|
||||
}))
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
if let Some(result) = handle_event(&event, &mut current_tool_call) {
|
||||
yield result;
|
||||
}
|
||||
|
@ -184,19 +211,21 @@ impl StreamingCompletionModel for CompletionModel {
|
|||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
});
|
||||
|
||||
Ok(streaming::StreamingCompletionResponse::new(stream))
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_event(
|
||||
event: &StreamingEvent,
|
||||
current_tool_call: &mut Option<ToolCallState>,
|
||||
) -> Option<Result<StreamingChoice, CompletionError>> {
|
||||
) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
|
||||
match event {
|
||||
StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
|
||||
ContentDelta::TextDelta { text } => {
|
||||
if current_tool_call.is_none() {
|
||||
return Some(Ok(StreamingChoice::Message(text.clone())));
|
||||
return Some(Ok(RawStreamingChoice::Message(text.clone())));
|
||||
}
|
||||
None
|
||||
}
|
||||
|
@ -227,7 +256,7 @@ fn handle_event(
|
|||
&tool_call.input_json
|
||||
};
|
||||
match serde_json::from_str(json_str) {
|
||||
Ok(json_value) => Some(Ok(StreamingChoice::ToolCall(
|
||||
Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall(
|
||||
tool_call.name,
|
||||
tool_call.id,
|
||||
json_value,
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
use super::openai::{send_compatible_streaming_request, TranscriptionResponse};
|
||||
|
||||
use crate::json_utils::merge;
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||
use crate::{
|
||||
agent::AgentBuilder,
|
||||
completion::{self, CompletionError, CompletionRequest},
|
||||
|
@ -570,10 +570,17 @@ impl completion::CompletionModel for CompletionModel {
|
|||
// Azure OpenAI Streaming API
|
||||
// -----------------------------------------------------
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
|
||||
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||
async fn stream(
|
||||
&self,
|
||||
request: CompletionRequest,
|
||||
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||
let mut request = self.create_completion_request(request)?;
|
||||
|
||||
request = merge(request, json!({"stream": true}));
|
||||
request = merge(
|
||||
request,
|
||||
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||
);
|
||||
|
||||
let builder = self
|
||||
.client
|
||||
|
|
|
@ -6,8 +6,9 @@ use crate::{
|
|||
};
|
||||
|
||||
use super::client::Client;
|
||||
use crate::completion::CompletionRequest;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CompletionResponse {
|
||||
|
@ -419,7 +420,7 @@ impl TryFrom<Message> for message::Message {
|
|||
|
||||
#[derive(Clone)]
|
||||
pub struct CompletionModel {
|
||||
client: Client,
|
||||
pub(crate) client: Client,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
|
@ -430,16 +431,11 @@ impl CompletionModel {
|
|||
model: model.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl completion::CompletionModel for CompletionModel {
|
||||
type Response = CompletionResponse;
|
||||
|
||||
#[cfg_attr(feature = "worker", worker::send)]
|
||||
async fn completion(
|
||||
pub(crate) fn create_completion_request(
|
||||
&self,
|
||||
completion_request: completion::CompletionRequest,
|
||||
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
|
||||
completion_request: CompletionRequest,
|
||||
) -> Result<Value, CompletionError> {
|
||||
let prompt = completion_request.prompt_with_context();
|
||||
|
||||
let mut messages: Vec<message::Message> =
|
||||
|
@ -468,23 +464,29 @@ impl completion::CompletionModel for CompletionModel {
|
|||
"tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
|
||||
});
|
||||
|
||||
if let Some(ref params) = completion_request.additional_params {
|
||||
Ok(json_utils::merge(request.clone(), params.clone()))
|
||||
} else {
|
||||
Ok(request)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl completion::CompletionModel for CompletionModel {
|
||||
type Response = CompletionResponse;
|
||||
|
||||
#[cfg_attr(feature = "worker", worker::send)]
|
||||
async fn completion(
|
||||
&self,
|
||||
completion_request: completion::CompletionRequest,
|
||||
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
|
||||
let request = self.create_completion_request(completion_request)?;
|
||||
tracing::debug!(
|
||||
"Cohere request: {}",
|
||||
serde_json::to_string_pretty(&request)?
|
||||
);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post("/v2/chat")
|
||||
.json(
|
||||
&if let Some(ref params) = completion_request.additional_params {
|
||||
json_utils::merge(request.clone(), params.clone())
|
||||
} else {
|
||||
request.clone()
|
||||
},
|
||||
)
|
||||
.send()
|
||||
.await?;
|
||||
let response = self.client.post("/v2/chat").json(&request).send().await?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let text_response = response.text().await?;
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
pub mod client;
|
||||
pub mod completion;
|
||||
pub mod embeddings;
|
||||
pub mod streaming;
|
||||
|
||||
pub use client::Client;
|
||||
pub use client::{ApiErrorResponse, ApiResponse};
|
||||
|
@ -23,7 +24,7 @@ pub use embeddings::EmbeddingModel;
|
|||
// ================================================================
|
||||
|
||||
/// `command-r-plus` completion model
|
||||
pub const COMMAND_R_PLUS: &str = "comman-r-plus";
|
||||
pub const COMMAND_R_PLUS: &str = "command-r-plus";
|
||||
/// `command-r` completion model
|
||||
pub const COMMAND_R: &str = "command-r";
|
||||
/// `command` completion model
|
||||
|
|
|
@ -0,0 +1,188 @@
|
|||
use crate::completion::{CompletionError, CompletionRequest};
|
||||
use crate::providers::cohere::completion::Usage;
|
||||
use crate::providers::cohere::CompletionModel;
|
||||
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
|
||||
use crate::{json_utils, streaming};
|
||||
use async_stream::stream;
|
||||
use futures::StreamExt;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case", tag = "type")]
|
||||
enum StreamingEvent {
|
||||
MessageStart,
|
||||
ContentStart,
|
||||
ContentDelta { delta: Option<Delta> },
|
||||
ContentEnd,
|
||||
ToolPlan,
|
||||
ToolCallStart { delta: Option<Delta> },
|
||||
ToolCallDelta { delta: Option<Delta> },
|
||||
ToolCallEnd,
|
||||
MessageEnd { delta: Option<MessageEndDelta> },
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct MessageContentDelta {
|
||||
text: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct MessageToolFunctionDelta {
|
||||
name: Option<String>,
|
||||
arguments: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct MessageToolCallDelta {
|
||||
id: Option<String>,
|
||||
function: Option<MessageToolFunctionDelta>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct MessageDelta {
|
||||
content: Option<MessageContentDelta>,
|
||||
tool_calls: Option<MessageToolCallDelta>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Delta {
|
||||
message: Option<MessageDelta>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct MessageEndDelta {
|
||||
usage: Option<Usage>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct StreamingCompletionResponse {
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
type StreamingResponse = StreamingCompletionResponse;
|
||||
|
||||
async fn stream(
|
||||
&self,
|
||||
request: CompletionRequest,
|
||||
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
|
||||
{
|
||||
let request = self.create_completion_request(request)?;
|
||||
let request = json_utils::merge(request, json!({"stream": true}));
|
||||
|
||||
tracing::debug!(
|
||||
"Cohere request: {}",
|
||||
serde_json::to_string_pretty(&request)?
|
||||
);
|
||||
|
||||
let response = self.client.post("/v2/chat").json(&request).send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(CompletionError::ProviderError(format!(
|
||||
"{}: {}",
|
||||
response.status(),
|
||||
response.text().await?
|
||||
)));
|
||||
}
|
||||
|
||||
let stream = Box::pin(stream! {
|
||||
let mut stream = response.bytes_stream();
|
||||
let mut current_tool_call: Option<(String, String, String)> = None;
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = match chunk_result {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
yield Err(CompletionError::from(e));
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let text = match String::from_utf8(chunk.to_vec()) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
yield Err(CompletionError::ResponseError(e.to_string()));
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
for line in text.lines() {
|
||||
|
||||
let Some(line) = line.strip_prefix("data: ") else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let event = {
|
||||
let result = serde_json::from_str::<StreamingEvent>(line);
|
||||
|
||||
let Ok(event) = result else {
|
||||
continue;
|
||||
};
|
||||
|
||||
event
|
||||
};
|
||||
|
||||
match event {
|
||||
StreamingEvent::ContentDelta { delta: Some(delta) } => {
|
||||
let Some(message) = &delta.message else { continue; };
|
||||
let Some(content) = &message.content else { continue; };
|
||||
let Some(text) = &content.text else { continue; };
|
||||
|
||||
yield Ok(RawStreamingChoice::Message(text.clone()));
|
||||
},
|
||||
StreamingEvent::MessageEnd {delta: Some(delta)} => {
|
||||
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
|
||||
usage: delta.usage.clone()
|
||||
}));
|
||||
},
|
||||
StreamingEvent::ToolCallStart { delta: Some(delta)} => {
|
||||
// Skip the delta if there's any missing information,
|
||||
// though this *should* all be present
|
||||
let Some(message) = &delta.message else { continue; };
|
||||
let Some(tool_calls) = &message.tool_calls else { continue; };
|
||||
let Some(id) = tool_calls.id.clone() else { continue; };
|
||||
let Some(function) = &tool_calls.function else { continue; };
|
||||
let Some(name) = function.name.clone() else { continue; };
|
||||
let Some(arguments) = function.arguments.clone() else { continue; };
|
||||
|
||||
current_tool_call = Some((id, name, arguments));
|
||||
},
|
||||
StreamingEvent::ToolCallDelta { delta: Some(delta)} => {
|
||||
// Skip the delta if there's any missing information,
|
||||
// though this *should* all be present
|
||||
let Some(message) = &delta.message else { continue; };
|
||||
let Some(tool_calls) = &message.tool_calls else { continue; };
|
||||
let Some(function) = &tool_calls.function else { continue; };
|
||||
let Some(arguments) = function.arguments.clone() else { continue; };
|
||||
|
||||
if let Some(tc) = current_tool_call.clone() {
|
||||
current_tool_call = Some((
|
||||
tc.0,
|
||||
tc.1,
|
||||
format!("{}{}", tc.2, arguments)
|
||||
));
|
||||
};
|
||||
},
|
||||
StreamingEvent::ToolCallEnd => {
|
||||
let Some(tc) = current_tool_call.clone() else { continue; };
|
||||
|
||||
let Ok(args) = serde_json::from_str(&tc.2) else { continue; };
|
||||
|
||||
yield Ok(RawStreamingChoice::ToolCall(
|
||||
tc.0,
|
||||
tc.1,
|
||||
args
|
||||
));
|
||||
|
||||
current_tool_call = None;
|
||||
},
|
||||
_ => {}
|
||||
};
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(streaming::StreamingCompletionResponse::new(stream))
|
||||
}
|
||||
}
|
|
@ -10,8 +10,9 @@
|
|||
//! ```
|
||||
|
||||
use crate::json_utils::merge;
|
||||
use crate::providers::openai;
|
||||
use crate::providers::openai::send_compatible_streaming_request;
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||
use crate::{
|
||||
completion::{self, CompletionError, CompletionModel, CompletionRequest},
|
||||
extractor::ExtractorBuilder,
|
||||
|
@ -463,13 +464,17 @@ impl CompletionModel for DeepSeekCompletionModel {
|
|||
}
|
||||
|
||||
impl StreamingCompletionModel for DeepSeekCompletionModel {
|
||||
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||
async fn stream(
|
||||
&self,
|
||||
completion_request: CompletionRequest,
|
||||
) -> Result<StreamingResult, CompletionError> {
|
||||
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||
let mut request = self.create_completion_request(completion_request)?;
|
||||
|
||||
request = merge(request, json!({"stream": true}));
|
||||
request = merge(
|
||||
request,
|
||||
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||
);
|
||||
|
||||
let builder = self.client.post("/v1/chat/completions").json(&request);
|
||||
send_compatible_streaming_request(builder).await
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
use super::openai;
|
||||
use crate::json_utils::merge;
|
||||
use crate::providers::openai::send_compatible_streaming_request;
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||
use crate::{
|
||||
agent::AgentBuilder,
|
||||
completion::{self, CompletionError, CompletionRequest},
|
||||
|
@ -495,10 +495,18 @@ impl completion::CompletionModel for CompletionModel {
|
|||
}
|
||||
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
|
||||
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||
|
||||
async fn stream(
|
||||
&self,
|
||||
request: CompletionRequest,
|
||||
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||
let mut request = self.create_completion_request(request)?;
|
||||
|
||||
request = merge(request, json!({"stream": true}));
|
||||
request = merge(
|
||||
request,
|
||||
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||
);
|
||||
|
||||
let builder = self.client.post("/chat/completions").json(&request);
|
||||
|
||||
|
|
|
@ -609,7 +609,7 @@ pub mod gemini_api_types {
|
|||
HarmCategoryCivicIntegrity,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, Clone, Default)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct UsageMetadata {
|
||||
pub prompt_token_count: i32,
|
||||
|
|
|
@ -2,12 +2,17 @@ use async_stream::stream;
|
|||
use futures::StreamExt;
|
||||
use serde::Deserialize;
|
||||
|
||||
use super::completion::{create_request_body, gemini_api_types::ContentCandidate, CompletionModel};
|
||||
use crate::{
|
||||
completion::{CompletionError, CompletionRequest},
|
||||
streaming::{self, StreamingCompletionModel, StreamingResult},
|
||||
streaming::{self, StreamingCompletionModel},
|
||||
};
|
||||
|
||||
use super::completion::{create_request_body, gemini_api_types::ContentCandidate, CompletionModel};
|
||||
/// Usage figures read from a streamed Gemini chunk.
///
/// Only the total token count is captured; the field is deserialized from its
/// camelCase JSON name (`totalTokenCount`).
#[derive(Debug, Deserialize, Default, Clone)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PartialUsage {
|
||||
pub total_token_count: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
|
@ -15,13 +20,21 @@ pub struct StreamGenerateContentResponse {
|
|||
/// Candidate responses from the model.
|
||||
pub candidates: Vec<ContentCandidate>,
|
||||
pub model_version: Option<String>,
|
||||
pub usage_metadata: Option<PartialUsage>,
|
||||
}
|
||||
|
||||
/// Final response item of a Gemini stream.
///
/// The streaming implementation yields it once a candidate reports a
/// `finish_reason`, filled with the usage metadata of that chunk.
#[derive(Clone)]
|
||||
pub struct StreamingCompletionResponse {
|
||||
pub usage_metadata: PartialUsage,
|
||||
}
|
||||
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
type StreamingResponse = StreamingCompletionResponse;
|
||||
async fn stream(
|
||||
&self,
|
||||
completion_request: CompletionRequest,
|
||||
) -> Result<StreamingResult, CompletionError> {
|
||||
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
|
||||
{
|
||||
let request = create_request_body(completion_request)?;
|
||||
|
||||
let response = self
|
||||
|
@ -42,7 +55,7 @@ impl StreamingCompletionModel for CompletionModel {
|
|||
)));
|
||||
}
|
||||
|
||||
Ok(Box::pin(stream! {
|
||||
let stream = Box::pin(stream! {
|
||||
let mut stream = response.bytes_stream();
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
|
@ -74,13 +87,23 @@ impl StreamingCompletionModel for CompletionModel {
|
|||
|
||||
match choice.content.parts.first() {
|
||||
super::completion::gemini_api_types::Part::Text(text)
|
||||
=> yield Ok(streaming::StreamingChoice::Message(text)),
|
||||
=> yield Ok(streaming::RawStreamingChoice::Message(text)),
|
||||
super::completion::gemini_api_types::Part::FunctionCall(function_call)
|
||||
=> yield Ok(streaming::StreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
|
||||
=> yield Ok(streaming::RawStreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
|
||||
_ => panic!("Unsupported response type with streaming.")
|
||||
};
|
||||
|
||||
if choice.finish_reason.is_some() {
|
||||
yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
|
||||
usage_metadata: PartialUsage {
|
||||
total_token_count: data.usage_metadata.unwrap().total_token_count,
|
||||
}
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
});
|
||||
|
||||
Ok(streaming::StreamingCompletionResponse::new(stream))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,7 +10,8 @@
|
|||
//! ```
|
||||
use super::openai::{send_compatible_streaming_request, CompletionResponse, TranscriptionResponse};
|
||||
use crate::json_utils::merge;
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
||||
use crate::providers::openai;
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||
use crate::{
|
||||
agent::AgentBuilder,
|
||||
completion::{self, CompletionError, CompletionRequest},
|
||||
|
@ -363,10 +364,17 @@ impl completion::CompletionModel for CompletionModel {
|
|||
}
|
||||
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
|
||||
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||
async fn stream(
|
||||
&self,
|
||||
request: CompletionRequest,
|
||||
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||
let mut request = self.create_completion_request(request)?;
|
||||
|
||||
request = merge(request, json!({"stream": true}));
|
||||
request = merge(
|
||||
request,
|
||||
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||
);
|
||||
|
||||
let builder = self.client.post("/chat/completions").json(&request);
|
||||
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
use super::completion::CompletionModel;
|
||||
use crate::completion::{CompletionError, CompletionRequest};
|
||||
use crate::json_utils;
|
||||
use crate::json_utils::merge_inplace;
|
||||
use crate::providers::openai::send_compatible_streaming_request;
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
||||
use crate::providers::openai::{send_compatible_streaming_request, StreamingCompletionResponse};
|
||||
use crate::streaming::StreamingCompletionModel;
|
||||
use crate::{json_utils, streaming};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use std::convert::Infallible;
|
||||
|
@ -55,14 +55,19 @@ struct CompletionChunk {
|
|||
}
|
||||
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
type StreamingResponse = StreamingCompletionResponse;
|
||||
async fn stream(
|
||||
&self,
|
||||
completion_request: CompletionRequest,
|
||||
) -> Result<StreamingResult, CompletionError> {
|
||||
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
|
||||
{
|
||||
let mut request = self.create_request_body(&completion_request)?;
|
||||
|
||||
// Enable streaming
|
||||
merge_inplace(&mut request, json!({"stream": true}));
|
||||
merge_inplace(
|
||||
&mut request,
|
||||
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||
);
|
||||
|
||||
if let Some(ref params) = completion_request.additional_params {
|
||||
merge_inplace(&mut request, params.clone());
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
use super::openai::{send_compatible_streaming_request, AssistantContent};
|
||||
|
||||
use crate::json_utils::merge_inplace;
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||
use crate::{
|
||||
agent::AgentBuilder,
|
||||
completion::{self, CompletionError, CompletionRequest},
|
||||
|
@ -390,13 +390,17 @@ impl completion::CompletionModel for CompletionModel {
|
|||
}
|
||||
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||
async fn stream(
|
||||
&self,
|
||||
completion_request: CompletionRequest,
|
||||
) -> Result<StreamingResult, CompletionError> {
|
||||
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||
let mut request = self.create_completion_request(completion_request)?;
|
||||
|
||||
merge_inplace(&mut request, json!({"stream": true}));
|
||||
merge_inplace(
|
||||
&mut request,
|
||||
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||
);
|
||||
|
||||
let builder = self.client.post("/chat/completions").json(&request);
|
||||
|
||||
|
@ -526,8 +530,10 @@ mod image_generation {
|
|||
// ======================================
|
||||
// Hyperbolic Audio Generation API
|
||||
// ======================================
|
||||
use crate::providers::openai;
|
||||
#[cfg(feature = "audio")]
|
||||
pub use audio_generation::*;
|
||||
|
||||
#[cfg(feature = "audio")]
|
||||
mod audio_generation {
|
||||
use super::{ApiResponse, Client};
|
||||
|
|
|
@ -8,8 +8,9 @@
|
|||
//!
|
||||
//! ```
|
||||
use crate::json_utils::merge;
|
||||
use crate::providers::openai;
|
||||
use crate::providers::openai::send_compatible_streaming_request;
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||
use crate::{
|
||||
agent::AgentBuilder,
|
||||
completion::{self, CompletionError, CompletionRequest},
|
||||
|
@ -347,10 +348,11 @@ impl completion::CompletionModel for CompletionModel {
|
|||
}
|
||||
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||
async fn stream(
|
||||
&self,
|
||||
completion_request: CompletionRequest,
|
||||
) -> Result<StreamingResult, CompletionError> {
|
||||
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||
let mut request = self.create_completion_request(completion_request)?;
|
||||
|
||||
request = merge(request, json!({"stream": true}));
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
|
||||
use crate::json_utils::merge;
|
||||
use crate::providers::openai::send_compatible_streaming_request;
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||
use crate::{
|
||||
agent::AgentBuilder,
|
||||
completion::{self, CompletionError, CompletionRequest},
|
||||
|
@ -228,10 +228,18 @@ impl completion::CompletionModel for CompletionModel {
|
|||
}
|
||||
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
|
||||
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||
|
||||
async fn stream(
|
||||
&self,
|
||||
request: CompletionRequest,
|
||||
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||
let mut request = self.create_completion_request(request)?;
|
||||
|
||||
request = merge(request, json!({"stream": true}));
|
||||
request = merge(
|
||||
request,
|
||||
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||
);
|
||||
|
||||
let builder = self.client.post("/chat/completions").json(&request);
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@
|
|||
//! let extractor = client.extractor::<serde_json::Value>("llama3.2");
|
||||
//! ```
|
||||
use crate::json_utils::merge_inplace;
|
||||
use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
|
||||
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
|
||||
use crate::{
|
||||
agent::AgentBuilder,
|
||||
completion::{self, CompletionError, CompletionRequest},
|
||||
|
@ -47,7 +47,7 @@ use crate::{
|
|||
extractor::ExtractorBuilder,
|
||||
json_utils, message,
|
||||
message::{ImageDetail, Text},
|
||||
Embed, OneOrMany,
|
||||
streaming, Embed, OneOrMany,
|
||||
};
|
||||
use async_stream::stream;
|
||||
use futures::StreamExt;
|
||||
|
@ -405,8 +405,25 @@ impl completion::CompletionModel for CompletionModel {
|
|||
}
|
||||
}
|
||||
|
||||
/// Final response item of an Ollama stream.
///
/// The streaming implementation yields it when a chunk has `done` set; every
/// field is copied from the same-named field of that chunk.
// NOTE(review): units of the `*_duration` fields are not visible here — confirm against the Ollama API.
#[derive(Clone)]
|
||||
pub struct StreamingCompletionResponse {
|
||||
pub done_reason: Option<String>,
|
||||
pub total_duration: Option<u64>,
|
||||
pub load_duration: Option<u64>,
|
||||
pub prompt_eval_count: Option<u64>,
|
||||
pub prompt_eval_duration: Option<u64>,
|
||||
pub eval_count: Option<u64>,
|
||||
pub eval_duration: Option<u64>,
|
||||
}
|
||||
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
|
||||
type StreamingResponse = StreamingCompletionResponse;
|
||||
|
||||
async fn stream(
|
||||
&self,
|
||||
request: CompletionRequest,
|
||||
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
|
||||
{
|
||||
let mut request_payload = self.create_completion_request(request)?;
|
||||
merge_inplace(&mut request_payload, json!({"stream": true}));
|
||||
|
||||
|
@ -426,7 +443,7 @@ impl StreamingCompletionModel for CompletionModel {
|
|||
return Err(CompletionError::ProviderError(err_text));
|
||||
}
|
||||
|
||||
Ok(Box::pin(stream! {
|
||||
let stream = Box::pin(stream! {
|
||||
let mut stream = response.bytes_stream();
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = match chunk_result {
|
||||
|
@ -456,22 +473,36 @@ impl StreamingCompletionModel for CompletionModel {
|
|||
match response.message {
|
||||
Message::Assistant{ content, tool_calls, .. } => {
|
||||
if !content.is_empty() {
|
||||
yield Ok(StreamingChoice::Message(content))
|
||||
yield Ok(RawStreamingChoice::Message(content))
|
||||
}
|
||||
|
||||
for tool_call in tool_calls.iter() {
|
||||
let function = tool_call.function.clone();
|
||||
|
||||
yield Ok(StreamingChoice::ToolCall(function.name, "".to_string(), function.arguments));
|
||||
yield Ok(RawStreamingChoice::ToolCall(function.name, "".to_string(), function.arguments));
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if response.done {
|
||||
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
|
||||
total_duration: response.total_duration,
|
||||
load_duration: response.load_duration,
|
||||
prompt_eval_count: response.prompt_eval_count,
|
||||
prompt_eval_duration: response.prompt_eval_duration,
|
||||
eval_count: response.eval_count,
|
||||
eval_duration: response.eval_duration,
|
||||
done_reason: response.done_reason,
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
});
|
||||
|
||||
Ok(streaming::StreamingCompletionResponse::new(stream))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2,14 +2,16 @@ use super::completion::CompletionModel;
|
|||
use crate::completion::{CompletionError, CompletionRequest};
|
||||
use crate::json_utils;
|
||||
use crate::json_utils::merge;
|
||||
use crate::providers::openai::Usage;
|
||||
use crate::streaming;
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
||||
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
|
||||
use async_stream::stream;
|
||||
use futures::StreamExt;
|
||||
use reqwest::RequestBuilder;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use tracing::debug;
|
||||
|
||||
// ================================================================
|
||||
// OpenAI Completion Streaming API
|
||||
|
@ -25,10 +27,11 @@ pub struct StreamingFunction {
|
|||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct StreamingToolCall {
|
||||
pub index: usize,
|
||||
pub id: Option<String>,
|
||||
pub function: StreamingFunction,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct StreamingDelta {
|
||||
#[serde(default)]
|
||||
content: Option<String>,
|
||||
|
@ -36,23 +39,34 @@ struct StreamingDelta {
|
|||
tool_calls: Vec<StreamingToolCall>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct StreamingChoice {
|
||||
delta: StreamingDelta,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct StreamingCompletionResponse {
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct StreamingCompletionChunk {
|
||||
choices: Vec<StreamingChoice>,
|
||||
usage: Option<Usage>,
|
||||
}
|
||||
|
||||
/// Final response item of an OpenAI-compatible stream.
///
/// `usage` is the last usage object seen in the stream's chunks; if no chunk
/// carried one, the counts are the zeroed initial value.
#[derive(Clone)]
|
||||
pub struct StreamingCompletionResponse {
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
type StreamingResponse = StreamingCompletionResponse;
|
||||
async fn stream(
|
||||
&self,
|
||||
completion_request: CompletionRequest,
|
||||
) -> Result<StreamingResult, CompletionError> {
|
||||
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
|
||||
{
|
||||
let mut request = self.create_completion_request(completion_request)?;
|
||||
request = merge(request, json!({"stream": true}));
|
||||
request = merge(
|
||||
request,
|
||||
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||
);
|
||||
|
||||
let builder = self.client.post("/chat/completions").json(&request);
|
||||
send_compatible_streaming_request(builder).await
|
||||
|
@ -61,7 +75,7 @@ impl StreamingCompletionModel for CompletionModel {
|
|||
|
||||
pub async fn send_compatible_streaming_request(
|
||||
request_builder: RequestBuilder,
|
||||
) -> Result<StreamingResult, CompletionError> {
|
||||
) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError> {
|
||||
let response = request_builder.send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
|
@ -73,11 +87,16 @@ pub async fn send_compatible_streaming_request(
|
|||
}
|
||||
|
||||
// Handle OpenAI Compatible SSE chunks
|
||||
Ok(Box::pin(stream! {
|
||||
let inner = Box::pin(stream! {
|
||||
let mut stream = response.bytes_stream();
|
||||
|
||||
let mut final_usage = Usage {
|
||||
prompt_tokens: 0,
|
||||
total_tokens: 0
|
||||
};
|
||||
|
||||
let mut partial_data = None;
|
||||
let mut calls: HashMap<usize, (String, String)> = HashMap::new();
|
||||
let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = match chunk_result {
|
||||
|
@ -100,8 +119,6 @@ pub async fn send_compatible_streaming_request(
|
|||
for line in text.lines() {
|
||||
let mut line = line.to_string();
|
||||
|
||||
|
||||
|
||||
// If there was a remaining part, concat with current line
|
||||
if partial_data.is_some() {
|
||||
line = format!("{}{}", partial_data.unwrap(), line);
|
||||
|
@ -121,64 +138,85 @@ pub async fn send_compatible_streaming_request(
|
|||
}
|
||||
}
|
||||
|
||||
let data = serde_json::from_str::<StreamingCompletionResponse>(&line);
|
||||
let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
|
||||
|
||||
let Ok(data) = data else {
|
||||
let err = data.unwrap_err();
|
||||
debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
|
||||
continue;
|
||||
};
|
||||
|
||||
let choice = data.choices.first().expect("Should have at least one choice");
|
||||
|
||||
let delta = &choice.delta;
|
||||
if let Some(choice) = data.choices.first() {
|
||||
|
||||
if !delta.tool_calls.is_empty() {
|
||||
for tool_call in &delta.tool_calls {
|
||||
let function = tool_call.function.clone();
|
||||
let delta = &choice.delta;
|
||||
|
||||
// Start of tool call
|
||||
// name: Some(String)
|
||||
// arguments: None
|
||||
if function.name.is_some() && function.arguments.is_empty() {
|
||||
calls.insert(tool_call.index, (function.name.clone().unwrap(), "".to_string()));
|
||||
if !delta.tool_calls.is_empty() {
|
||||
for tool_call in &delta.tool_calls {
|
||||
let function = tool_call.function.clone();
|
||||
// Start of tool call
|
||||
// name: Some(String)
|
||||
// arguments: None
|
||||
if function.name.is_some() && function.arguments.is_empty() {
|
||||
let id = tool_call.id.clone().unwrap_or("".to_string());
|
||||
|
||||
calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
|
||||
}
|
||||
// Part of tool call
|
||||
// name: None
|
||||
// arguments: Some(String)
|
||||
else if function.name.is_none() && !function.arguments.is_empty() {
|
||||
let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
|
||||
debug!("Partial tool call received but tool call was never started.");
|
||||
continue;
|
||||
};
|
||||
|
||||
let new_arguments = &tool_call.function.arguments;
|
||||
let arguments = format!("{}{}", arguments, new_arguments);
|
||||
|
||||
calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
|
||||
}
|
||||
// Entire tool call
|
||||
else {
|
||||
let id = tool_call.id.clone().unwrap_or("".to_string());
|
||||
let name = function.name.expect("function name should be present for complete tool call");
|
||||
let arguments = function.arguments;
|
||||
let Ok(arguments) = serde_json::from_str(&arguments) else {
|
||||
debug!("Couldn't serialize '{}' as a json value", arguments);
|
||||
continue;
|
||||
};
|
||||
|
||||
yield Ok(streaming::RawStreamingChoice::ToolCall(id, name, arguments))
|
||||
}
|
||||
}
|
||||
// Part of tool call
|
||||
// name: None
|
||||
// arguments: Some(String)
|
||||
else if function.name.is_none() && !function.arguments.is_empty() {
|
||||
let Some((name, arguments)) = calls.get(&tool_call.index) else {
|
||||
continue;
|
||||
};
|
||||
}
|
||||
|
||||
let new_arguments = &tool_call.function.arguments;
|
||||
let arguments = format!("{}{}", arguments, new_arguments);
|
||||
|
||||
calls.insert(tool_call.index, (name.clone(), arguments));
|
||||
}
|
||||
// Entire tool call
|
||||
else {
|
||||
let name = function.name.unwrap();
|
||||
let arguments = function.arguments;
|
||||
let Ok(arguments) = serde_json::from_str(&arguments) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
|
||||
}
|
||||
if let Some(content) = &choice.delta.content {
|
||||
yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(content) = &choice.delta.content {
|
||||
yield Ok(streaming::StreamingChoice::Message(content.clone()))
|
||||
|
||||
if let Some(usage) = data.usage {
|
||||
final_usage = usage.clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (_, (name, arguments)) in calls {
|
||||
for (_, (id, name, arguments)) in calls {
|
||||
let Ok(arguments) = serde_json::from_str(&arguments) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
|
||||
println!("{id} {name}");
|
||||
|
||||
yield Ok(RawStreamingChoice::ToolCall(id, name, arguments))
|
||||
}
|
||||
}))
|
||||
|
||||
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
|
||||
usage: final_usage.clone()
|
||||
}))
|
||||
});
|
||||
|
||||
Ok(streaming::StreamingCompletionResponse::new(inner))
|
||||
}
|
||||
|
|
|
@ -0,0 +1,125 @@
|
|||
use crate::{agent::AgentBuilder, extractor::ExtractorBuilder};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::completion::CompletionModel;
|
||||
|
||||
// ================================================================
|
||||
// Main openrouter Client
|
||||
// ================================================================
|
||||
const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
|
||||
|
||||
/// OpenRouter API client.
///
/// Holds the base URL that request paths are appended to, and a `reqwest`
/// client whose default headers carry the `Authorization: Bearer <key>` header.
#[derive(Clone)]
|
||||
pub struct Client {
|
||||
base_url: String,
|
||||
http_client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
/// Create a new OpenRouter client with the given API key.
|
||||
pub fn new(api_key: &str) -> Self {
|
||||
Self::from_url(api_key, OPENROUTER_API_BASE_URL)
|
||||
}
|
||||
|
||||
/// Create a new OpenRouter client with the given API key and base API URL.
|
||||
pub fn from_url(api_key: &str, base_url: &str) -> Self {
|
||||
Self {
|
||||
base_url: base_url.to_string(),
|
||||
http_client: reqwest::Client::builder()
|
||||
.default_headers({
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
"Authorization",
|
||||
format!("Bearer {}", api_key)
|
||||
.parse()
|
||||
.expect("Bearer token should parse"),
|
||||
);
|
||||
headers
|
||||
})
|
||||
.build()
|
||||
.expect("OpenRouter reqwest client should build"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new openrouter client from the `openrouter_API_KEY` environment variable.
|
||||
/// Panics if the environment variable is not set.
|
||||
pub fn from_env() -> Self {
|
||||
let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
|
||||
Self::new(&api_key)
|
||||
}
|
||||
|
||||
pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
|
||||
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
|
||||
self.http_client.post(url)
|
||||
}
|
||||
|
||||
/// Create a completion model with the given name.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use rig::providers::openrouter::{Client, self};
|
||||
///
|
||||
/// // Initialize the openrouter client
|
||||
/// let openrouter = Client::new("your-openrouter-api-key");
|
||||
///
|
||||
/// let llama_3_1_8b = openrouter.completion_model(openrouter::LLAMA_3_1_8B);
|
||||
/// ```
|
||||
pub fn completion_model(&self, model: &str) -> CompletionModel {
|
||||
CompletionModel::new(self.clone(), model)
|
||||
}
|
||||
|
||||
/// Create an agent builder with the given completion model.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use rig::providers::openrouter::{Client, self};
|
||||
///
|
||||
/// // Initialize the Eternal client
|
||||
/// let openrouter = Client::new("your-openrouter-api-key");
|
||||
///
|
||||
/// let agent = openrouter.agent(openrouter::LLAMA_3_1_8B)
|
||||
/// .preamble("You are comedian AI with a mission to make people laugh.")
|
||||
/// .temperature(0.0)
|
||||
/// .build();
|
||||
/// ```
|
||||
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
|
||||
AgentBuilder::new(self.completion_model(model))
|
||||
}
|
||||
|
||||
/// Create an extractor builder with the given completion model.
|
||||
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
|
||||
&self,
|
||||
model: &str,
|
||||
) -> ExtractorBuilder<T, CompletionModel> {
|
||||
ExtractorBuilder::new(self.completion_model(model))
|
||||
}
|
||||
}
|
||||
|
||||
/// Error payload of [`ApiResponse`], deserialized from a JSON object with a
/// `message` field.
#[derive(Debug, Deserialize)]
|
||||
pub struct ApiErrorResponse {
|
||||
pub(crate) message: String,
|
||||
}
|
||||
|
||||
/// Response envelope that is either a successful payload or an API error.
///
/// The enum is `untagged`, so deserialization tries `Ok(T)` first and falls
/// back to `Err(ApiErrorResponse)` when the body does not match `T`.
#[derive(Debug, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ApiResponse<T> {
|
||||
Ok(T),
|
||||
Err(ApiErrorResponse),
|
||||
}
|
||||
|
||||
/// Token counts deserialized from an API response.
// NOTE(review): presumably `total_tokens` = `prompt_tokens` + `completion_tokens` — confirm against the OpenRouter API.
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: usize,
|
||||
pub completion_tokens: usize,
|
||||
pub total_tokens: usize,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Usage {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Prompt tokens: {} Total tokens: {}",
|
||||
self.prompt_tokens, self.total_tokens
|
||||
)
|
||||
}
|
||||
}
|
|
@ -1,147 +1,16 @@
|
|||
//! OpenRouter Inference API client and Rig integration
|
||||
//!
|
||||
//! # Example
|
||||
//! ```
|
||||
//! use rig::providers::openrouter;
|
||||
//!
|
||||
//! let client = openrouter::Client::new("YOUR_API_KEY");
|
||||
//!
|
||||
//! let llama_3_1_8b = client.completion_model(openrouter::LLAMA_3_1_8B);
|
||||
//! ```
|
||||
use serde::Deserialize;
|
||||
|
||||
use super::client::{ApiErrorResponse, ApiResponse, Client, Usage};
|
||||
|
||||
use crate::{
|
||||
agent::AgentBuilder,
|
||||
completion::{self, CompletionError, CompletionRequest},
|
||||
extractor::ExtractorBuilder,
|
||||
json_utils,
|
||||
providers::openai::Message,
|
||||
OneOrMany,
|
||||
};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use super::openai::AssistantContent;
|
||||
|
||||
// ================================================================
|
||||
// Main openrouter Client
|
||||
// ================================================================
|
||||
const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Client {
|
||||
base_url: String,
|
||||
http_client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
/// Create a new OpenRouter client with the given API key.
|
||||
pub fn new(api_key: &str) -> Self {
|
||||
Self::from_url(api_key, OPENROUTER_API_BASE_URL)
|
||||
}
|
||||
|
||||
/// Create a new OpenRouter client with the given API key and base API URL.
|
||||
pub fn from_url(api_key: &str, base_url: &str) -> Self {
|
||||
Self {
|
||||
base_url: base_url.to_string(),
|
||||
http_client: reqwest::Client::builder()
|
||||
.default_headers({
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
"Authorization",
|
||||
format!("Bearer {}", api_key)
|
||||
.parse()
|
||||
.expect("Bearer token should parse"),
|
||||
);
|
||||
headers
|
||||
})
|
||||
.build()
|
||||
.expect("OpenRouter reqwest client should build"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new openrouter client from the `openrouter_API_KEY` environment variable.
|
||||
/// Panics if the environment variable is not set.
|
||||
pub fn from_env() -> Self {
|
||||
let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
|
||||
Self::new(&api_key)
|
||||
}
|
||||
|
||||
fn post(&self, path: &str) -> reqwest::RequestBuilder {
|
||||
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
|
||||
self.http_client.post(url)
|
||||
}
|
||||
|
||||
/// Create a completion model with the given name.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use rig::providers::openrouter::{Client, self};
|
||||
///
|
||||
/// // Initialize the openrouter client
|
||||
/// let openrouter = Client::new("your-openrouter-api-key");
|
||||
///
|
||||
/// let llama_3_1_8b = openrouter.completion_model(openrouter::LLAMA_3_1_8B);
|
||||
/// ```
|
||||
pub fn completion_model(&self, model: &str) -> CompletionModel {
|
||||
CompletionModel::new(self.clone(), model)
|
||||
}
|
||||
|
||||
/// Create an agent builder with the given completion model.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use rig::providers::openrouter::{Client, self};
|
||||
///
|
||||
/// // Initialize the Eternal client
|
||||
/// let openrouter = Client::new("your-openrouter-api-key");
|
||||
///
|
||||
/// let agent = openrouter.agent(openrouter::LLAMA_3_1_8B)
|
||||
/// .preamble("You are comedian AI with a mission to make people laugh.")
|
||||
/// .temperature(0.0)
|
||||
/// .build();
|
||||
/// ```
|
||||
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
|
||||
AgentBuilder::new(self.completion_model(model))
|
||||
}
|
||||
|
||||
/// Create an extractor builder with the given completion model.
|
||||
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
|
||||
&self,
|
||||
model: &str,
|
||||
) -> ExtractorBuilder<T, CompletionModel> {
|
||||
ExtractorBuilder::new(self.completion_model(model))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ApiErrorResponse {
|
||||
message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum ApiResponse<T> {
|
||||
Ok(T),
|
||||
Err(ApiErrorResponse),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: usize,
|
||||
pub completion_tokens: usize,
|
||||
pub total_tokens: usize,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Usage {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Prompt tokens: {} Total tokens: {}",
|
||||
self.prompt_tokens, self.total_tokens
|
||||
)
|
||||
}
|
||||
}
|
||||
use crate::providers::openai::AssistantContent;
|
||||
|
||||
// ================================================================
|
||||
// OpenRouter Completion API
|
||||
|
@ -241,7 +110,7 @@ pub struct Choice {
|
|||
|
||||
#[derive(Clone)]
|
||||
pub struct CompletionModel {
|
||||
client: Client,
|
||||
pub(crate) client: Client,
|
||||
/// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
|
||||
pub model: String,
|
||||
}
|
||||
|
@ -253,16 +122,11 @@ impl CompletionModel {
|
|||
model: model.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl completion::CompletionModel for CompletionModel {
|
||||
type Response = CompletionResponse;
|
||||
|
||||
#[cfg_attr(feature = "worker", worker::send)]
|
||||
async fn completion(
|
||||
pub(crate) fn create_completion_request(
|
||||
&self,
|
||||
completion_request: CompletionRequest,
|
||||
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
|
||||
) -> Result<Value, CompletionError> {
|
||||
// Add preamble to chat history (if available)
|
||||
let mut full_history: Vec<Message> = match &completion_request.preamble {
|
||||
Some(preamble) => vec![Message::system(preamble)],
|
||||
|
@ -292,16 +156,30 @@ impl completion::CompletionModel for CompletionModel {
|
|||
"temperature": completion_request.temperature,
|
||||
});
|
||||
|
||||
let request = if let Some(params) = completion_request.additional_params {
|
||||
json_utils::merge(request, params)
|
||||
} else {
|
||||
request
|
||||
};
|
||||
|
||||
Ok(request)
|
||||
}
|
||||
}
|
||||
|
||||
impl completion::CompletionModel for CompletionModel {
|
||||
type Response = CompletionResponse;
|
||||
|
||||
#[cfg_attr(feature = "worker", worker::send)]
|
||||
async fn completion(
|
||||
&self,
|
||||
completion_request: CompletionRequest,
|
||||
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
|
||||
let request = self.create_completion_request(completion_request)?;
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post("/chat/completions")
|
||||
.json(
|
||||
&if let Some(params) = completion_request.additional_params {
|
||||
json_utils::merge(request, params)
|
||||
} else {
|
||||
request
|
||||
},
|
||||
)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
//! OpenRouter Inference API client and Rig integration
|
||||
//!
|
||||
//! # Example
|
||||
//! ```
|
||||
//! use rig::providers::openrouter;
|
||||
//!
|
||||
//! let client = openrouter::Client::new("YOUR_API_KEY");
|
||||
//!
|
||||
//! let llama_3_1_8b = client.completion_model(openrouter::LLAMA_3_1_8B);
|
||||
//! ```
|
||||
|
||||
pub mod client;
|
||||
pub mod completion;
|
||||
pub mod streaming;
|
||||
|
||||
pub use client::*;
|
||||
pub use completion::*;
|
|
@ -0,0 +1,313 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use crate::{
|
||||
json_utils,
|
||||
message::{ToolCall, ToolFunction},
|
||||
streaming::{self},
|
||||
};
|
||||
use async_stream::stream;
|
||||
use futures::StreamExt;
|
||||
use reqwest::RequestBuilder;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use crate::{
|
||||
completion::{CompletionError, CompletionRequest},
|
||||
streaming::StreamingCompletionModel,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct StreamingCompletionResponse {
|
||||
pub id: String,
|
||||
pub choices: Vec<StreamingChoice>,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub object: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub system_fingerprint: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub usage: Option<ResponseUsage>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct StreamingChoice {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub finish_reason: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub native_finish_reason: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub logprobs: Option<Value>,
|
||||
pub index: usize,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub message: Option<MessageResponse>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub delta: Option<DeltaResponse>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<ErrorResponse>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct MessageResponse {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub refusal: Option<Value>,
|
||||
#[serde(default)]
|
||||
pub tool_calls: Vec<OpenRouterToolCall>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct OpenRouterToolFunction {
|
||||
pub name: Option<String>,
|
||||
pub arguments: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct OpenRouterToolCall {
|
||||
pub index: usize,
|
||||
pub id: Option<String>,
|
||||
pub r#type: Option<String>,
|
||||
pub function: OpenRouterToolFunction,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
|
||||
pub struct ResponseUsage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct ErrorResponse {
|
||||
pub code: i32,
|
||||
pub message: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<HashMap<String, Value>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct DeltaResponse {
|
||||
pub role: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
#[serde(default)]
|
||||
pub tool_calls: Vec<OpenRouterToolCall>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub native_finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FinalCompletionResponse {
|
||||
pub usage: ResponseUsage,
|
||||
}
|
||||
|
||||
impl StreamingCompletionModel for super::CompletionModel {
|
||||
type StreamingResponse = FinalCompletionResponse;
|
||||
|
||||
async fn stream(
|
||||
&self,
|
||||
completion_request: CompletionRequest,
|
||||
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
|
||||
{
|
||||
let request = self.create_completion_request(completion_request)?;
|
||||
|
||||
let request = json_utils::merge(request, json!({"stream": true}));
|
||||
|
||||
let builder = self.client.post("/chat/completions").json(&request);
|
||||
|
||||
send_streaming_request(builder).await
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_streaming_request(
|
||||
request_builder: RequestBuilder,
|
||||
) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError> {
|
||||
let response = request_builder.send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(CompletionError::ProviderError(format!(
|
||||
"{}: {}",
|
||||
response.status(),
|
||||
response.text().await?
|
||||
)));
|
||||
}
|
||||
|
||||
// Handle OpenAI Compatible SSE chunks
|
||||
let stream = Box::pin(stream! {
|
||||
let mut stream = response.bytes_stream();
|
||||
let mut tool_calls = HashMap::new();
|
||||
let mut partial_line = String::new();
|
||||
let mut final_usage = None;
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = match chunk_result {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
yield Err(CompletionError::from(e));
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let text = match String::from_utf8(chunk.to_vec()) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
yield Err(CompletionError::ResponseError(e.to_string()));
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
for line in text.lines() {
|
||||
let mut line = line.to_string();
|
||||
|
||||
// Skip empty lines and processing messages, as well as [DONE] (might be useful though)
|
||||
if line.trim().is_empty() || line.trim() == ": OPENROUTER PROCESSING" || line.trim() == "data: [DONE]" {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle data: prefix
|
||||
line = line.strip_prefix("data: ").unwrap_or(&line).to_string();
|
||||
|
||||
// If line starts with { but doesn't end with }, it's a partial JSON
|
||||
if line.starts_with('{') && !line.ends_with('}') {
|
||||
partial_line = line;
|
||||
continue;
|
||||
}
|
||||
|
||||
// If we have a partial line and this line ends with }, complete it
|
||||
if !partial_line.is_empty() {
|
||||
if line.ends_with('}') {
|
||||
partial_line.push_str(&line);
|
||||
line = partial_line;
|
||||
partial_line = String::new();
|
||||
} else {
|
||||
partial_line.push_str(&line);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
let data = match serde_json::from_str::<StreamingCompletionResponse>(&line) {
|
||||
Ok(data) => data,
|
||||
Err(_) => {
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
let choice = data.choices.first().expect("Should have at least one choice");
|
||||
|
||||
// TODO this has to handle outputs like this:
|
||||
// [{"index": 0, "id": "call_DdmO9pD3xa9XTPNJ32zg2hcA", "function": {"arguments": "", "name": "get_weather"}, "type": "function"}]
|
||||
// [{"index": 0, "id": null, "function": {"arguments": "{\"", "name": null}, "type": null}]
|
||||
// [{"index": 0, "id": null, "function": {"arguments": "location", "name": null}, "type": null}]
|
||||
// [{"index": 0, "id": null, "function": {"arguments": "\":\"", "name": null}, "type": null}]
|
||||
// [{"index": 0, "id": null, "function": {"arguments": "Paris", "name": null}, "type": null}]
|
||||
// [{"index": 0, "id": null, "function": {"arguments": ",", "name": null}, "type": null}]
|
||||
// [{"index": 0, "id": null, "function": {"arguments": " France", "name": null}, "type": null}]
|
||||
// [{"index": 0, "id": null, "function": {"arguments": "\"}", "name": null}, "type": null}]
|
||||
if let Some(delta) = &choice.delta {
|
||||
if !delta.tool_calls.is_empty() {
|
||||
for tool_call in &delta.tool_calls {
|
||||
let index = tool_call.index;
|
||||
|
||||
// Get or create tool call entry
|
||||
let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
|
||||
id: String::new(),
|
||||
function: ToolFunction {
|
||||
name: String::new(),
|
||||
arguments: serde_json::Value::Null,
|
||||
},
|
||||
});
|
||||
|
||||
// Update fields if present
|
||||
if let Some(id) = &tool_call.id {
|
||||
if !id.is_empty() {
|
||||
existing_tool_call.id = id.clone();
|
||||
}
|
||||
}
|
||||
if let Some(name) = &tool_call.function.name {
|
||||
if !name.is_empty() {
|
||||
existing_tool_call.function.name = name.clone();
|
||||
}
|
||||
}
|
||||
if let Some(chunk) = &tool_call.function.arguments {
|
||||
// Convert current arguments to string if needed
|
||||
let current_args = match &existing_tool_call.function.arguments {
|
||||
serde_json::Value::Null => String::new(),
|
||||
serde_json::Value::String(s) => s.clone(),
|
||||
v => v.to_string(),
|
||||
};
|
||||
|
||||
// Concatenate the new chunk
|
||||
let combined = format!("{}{}", current_args, chunk);
|
||||
|
||||
// Try to parse as JSON if it looks complete
|
||||
if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
|
||||
match serde_json::from_str(&combined) {
|
||||
Ok(parsed) => existing_tool_call.function.arguments = parsed,
|
||||
Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
|
||||
}
|
||||
} else {
|
||||
existing_tool_call.function.arguments = serde_json::Value::String(combined);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(content) = &delta.content {
|
||||
if !content.is_empty() {
|
||||
yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(usage) = data.usage {
|
||||
final_usage = Some(usage);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle message format
|
||||
if let Some(message) = &choice.message {
|
||||
if !message.tool_calls.is_empty() {
|
||||
for tool_call in &message.tool_calls {
|
||||
let name = tool_call.function.name.clone();
|
||||
let id = tool_call.id.clone();
|
||||
let arguments = if let Some(args) = &tool_call.function.arguments {
|
||||
// Try to parse the string as JSON, fallback to string value
|
||||
match serde_json::from_str(args) {
|
||||
Ok(v) => v,
|
||||
Err(_) => serde_json::Value::String(args.to_string()),
|
||||
}
|
||||
} else {
|
||||
serde_json::Value::Null
|
||||
};
|
||||
let index = tool_call.index;
|
||||
|
||||
tool_calls.insert(index, ToolCall{
|
||||
id: id.unwrap_or_default(),
|
||||
function: ToolFunction {
|
||||
name: name.unwrap_or_default(),
|
||||
arguments,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if !message.content.is_empty() {
|
||||
yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (_, tool_call) in tool_calls.into_iter() {
|
||||
|
||||
yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call.function.name, tool_call.id, tool_call.function.arguments));
|
||||
}
|
||||
|
||||
yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
|
||||
usage: final_usage.unwrap_or_default()
|
||||
}))
|
||||
|
||||
});
|
||||
|
||||
Ok(streaming::StreamingCompletionResponse::new(stream))
|
||||
}
|
|
@ -18,8 +18,9 @@ use crate::{
|
|||
|
||||
use crate::completion::CompletionRequest;
|
||||
use crate::json_utils::merge;
|
||||
use crate::providers::openai;
|
||||
use crate::providers::openai::send_compatible_streaming_request;
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
|
@ -345,10 +346,11 @@ impl completion::CompletionModel for CompletionModel {
|
|||
}
|
||||
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||
async fn stream(
|
||||
&self,
|
||||
completion_request: completion::CompletionRequest,
|
||||
) -> Result<StreamingResult, CompletionError> {
|
||||
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||
let mut request = self.create_completion_request(completion_request)?;
|
||||
|
||||
request = merge(request, json!({"stream": true}));
|
||||
|
|
|
@ -1,18 +1,21 @@
|
|||
use serde_json::json;
|
||||
|
||||
use super::completion::CompletionModel;
|
||||
use crate::providers::openai;
|
||||
use crate::providers::openai::send_compatible_streaming_request;
|
||||
use crate::streaming::StreamingCompletionResponse;
|
||||
use crate::{
|
||||
completion::{CompletionError, CompletionRequest},
|
||||
json_utils::merge,
|
||||
streaming::{StreamingCompletionModel, StreamingResult},
|
||||
streaming::StreamingCompletionModel,
|
||||
};
|
||||
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||
async fn stream(
|
||||
&self,
|
||||
completion_request: CompletionRequest,
|
||||
) -> Result<StreamingResult, CompletionError> {
|
||||
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||
let mut request = self.create_completion_request(completion_request)?;
|
||||
|
||||
request = merge(request, json!({"stream_tokens": true}));
|
||||
|
|
|
@ -1,15 +1,17 @@
|
|||
use crate::completion::{CompletionError, CompletionRequest};
|
||||
use crate::json_utils::merge;
|
||||
use crate::providers::openai;
|
||||
use crate::providers::openai::send_compatible_streaming_request;
|
||||
use crate::providers::xai::completion::CompletionModel;
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||
use serde_json::json;
|
||||
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||
async fn stream(
|
||||
&self,
|
||||
completion_request: CompletionRequest,
|
||||
) -> Result<StreamingResult, CompletionError> {
|
||||
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||
let mut request = self.create_completion_request(completion_request)?;
|
||||
|
||||
request = merge(request, json!({"stream": true}));
|
||||
|
|
|
@ -11,59 +11,150 @@
|
|||
|
||||
use crate::agent::Agent;
|
||||
use crate::completion::{
|
||||
CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, Message,
|
||||
CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder,
|
||||
CompletionResponse, Message,
|
||||
};
|
||||
use crate::message::{AssistantContent, ToolCall, ToolFunction};
|
||||
use crate::OneOrMany;
|
||||
use futures::{Stream, StreamExt};
|
||||
use std::boxed::Box;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
/// Enum representing a streaming chunk from the model
|
||||
#[derive(Debug)]
|
||||
pub enum StreamingChoice {
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum RawStreamingChoice<R: Clone> {
|
||||
/// A text chunk from a message response
|
||||
Message(String),
|
||||
|
||||
/// A tool call response chunk
|
||||
ToolCall(String, String, serde_json::Value),
|
||||
|
||||
/// The final response object, must be yielded if you want the
|
||||
/// `response` field to be populated on the `StreamingCompletionResponse`
|
||||
FinalResponse(R),
|
||||
}
|
||||
|
||||
impl Display for StreamingChoice {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
StreamingChoice::Message(text) => write!(f, "{}", text),
|
||||
StreamingChoice::ToolCall(name, id, params) => {
|
||||
write!(f, "Tool call: {} {} {:?}", name, id, params)
|
||||
}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub type StreamingResult<R> =
|
||||
Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
pub type StreamingResult<R> =
|
||||
Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
|
||||
|
||||
/// The response from a streaming completion request;
|
||||
/// message and response are populated at the end of the
|
||||
/// `inner` stream.
|
||||
pub struct StreamingCompletionResponse<R: Clone + Unpin> {
|
||||
inner: StreamingResult<R>,
|
||||
text: String,
|
||||
tool_calls: Vec<ToolCall>,
|
||||
/// The final aggregated message from the stream
|
||||
/// contains all text and tool calls generated
|
||||
pub choice: OneOrMany<AssistantContent>,
|
||||
/// The final response from the stream, may be `None`
|
||||
/// if the provider didn't yield it during the stream
|
||||
pub response: Option<R>,
|
||||
}
|
||||
|
||||
impl<R: Clone + Unpin> StreamingCompletionResponse<R> {
|
||||
pub fn new(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
|
||||
Self {
|
||||
inner,
|
||||
text: "".to_string(),
|
||||
tool_calls: vec![],
|
||||
choice: OneOrMany::one(AssistantContent::text("")),
|
||||
response: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub type StreamingResult =
|
||||
Pin<Box<dyn Stream<Item = Result<StreamingChoice, CompletionError>> + Send>>;
|
||||
impl<R: Clone + Unpin> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>> {
|
||||
fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
|
||||
CompletionResponse {
|
||||
choice: value.choice,
|
||||
raw_response: value.response,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
pub type StreamingResult = Pin<Box<dyn Stream<Item = Result<StreamingChoice, CompletionError>>>>;
|
||||
impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
|
||||
type Item = Result<AssistantContent, CompletionError>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let stream = self.get_mut();
|
||||
|
||||
match stream.inner.as_mut().poll_next(cx) {
|
||||
Poll::Pending => Poll::Pending,
|
||||
Poll::Ready(None) => {
|
||||
// This is run at the end of the inner stream to collect all tokens into
|
||||
// a single unified `Message`.
|
||||
let mut choice = vec![];
|
||||
|
||||
stream.tool_calls.iter().for_each(|tc| {
|
||||
choice.push(AssistantContent::ToolCall(tc.clone()));
|
||||
});
|
||||
|
||||
// This is required to ensure there's always at least one item in the content
|
||||
if choice.is_empty() || !stream.text.is_empty() {
|
||||
choice.insert(0, AssistantContent::text(stream.text.clone()));
|
||||
}
|
||||
|
||||
stream.choice = OneOrMany::many(choice)
|
||||
.expect("There should be at least one assistant message");
|
||||
|
||||
Poll::Ready(None)
|
||||
}
|
||||
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
|
||||
Poll::Ready(Some(Ok(choice))) => match choice {
|
||||
RawStreamingChoice::Message(text) => {
|
||||
// Forward the streaming tokens to the outer stream
|
||||
// and concat the text together
|
||||
stream.text = format!("{}{}", stream.text, text.clone());
|
||||
Poll::Ready(Some(Ok(AssistantContent::text(text))))
|
||||
}
|
||||
RawStreamingChoice::ToolCall(id, name, args) => {
|
||||
// Keep track of each tool call to aggregate the final message later
|
||||
// and pass it to the outer stream
|
||||
stream.tool_calls.push(ToolCall {
|
||||
id: id.clone(),
|
||||
function: ToolFunction {
|
||||
name: name.clone(),
|
||||
arguments: args.clone(),
|
||||
},
|
||||
});
|
||||
Poll::Ready(Some(Ok(AssistantContent::tool_call(id, name, args))))
|
||||
}
|
||||
RawStreamingChoice::FinalResponse(response) => {
|
||||
// Set the final response field and return the next item in the stream
|
||||
stream.response = Some(response);
|
||||
|
||||
stream.poll_next_unpin(cx)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for high-level streaming prompt interface
|
||||
pub trait StreamingPrompt: Send + Sync {
|
||||
pub trait StreamingPrompt<R: Clone + Unpin>: Send + Sync {
|
||||
/// Stream a simple prompt to the model
|
||||
fn stream_prompt(
|
||||
&self,
|
||||
prompt: &str,
|
||||
) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
|
||||
) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
|
||||
}
|
||||
|
||||
/// Trait for high-level streaming chat interface
|
||||
pub trait StreamingChat: Send + Sync {
|
||||
pub trait StreamingChat<R: Clone + Unpin>: Send + Sync {
|
||||
/// Stream a chat with history to the model
|
||||
fn stream_chat(
|
||||
&self,
|
||||
prompt: &str,
|
||||
chat_history: Vec<Message>,
|
||||
) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
|
||||
) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
|
||||
}
|
||||
|
||||
/// Trait for low-level streaming completion interface
|
||||
|
@ -78,29 +169,35 @@ pub trait StreamingCompletion<M: StreamingCompletionModel> {
|
|||
|
||||
/// Trait defining a streaming completion model
|
||||
pub trait StreamingCompletionModel: CompletionModel {
|
||||
type StreamingResponse: Clone + Unpin;
|
||||
/// Stream a completion response for the given request
|
||||
fn stream(
|
||||
&self,
|
||||
request: CompletionRequest,
|
||||
) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
|
||||
) -> impl Future<
|
||||
Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
|
||||
>;
|
||||
}
|
||||
|
||||
/// helper function to stream a completion request to stdout
|
||||
pub async fn stream_to_stdout<M: StreamingCompletionModel>(
|
||||
agent: Agent<M>,
|
||||
stream: &mut StreamingResult,
|
||||
stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
|
||||
) -> Result<(), std::io::Error> {
|
||||
print!("Response: ");
|
||||
while let Some(chunk) = stream.next().await {
|
||||
match chunk {
|
||||
Ok(StreamingChoice::Message(text)) => {
|
||||
print!("{}", text);
|
||||
Ok(AssistantContent::Text(text)) => {
|
||||
print!("{}", text.text);
|
||||
std::io::Write::flush(&mut std::io::stdout())?;
|
||||
}
|
||||
Ok(StreamingChoice::ToolCall(name, _, params)) => {
|
||||
Ok(AssistantContent::ToolCall(tool_call)) => {
|
||||
let res = agent
|
||||
.tools
|
||||
.call(&name, params.to_string())
|
||||
.call(
|
||||
&tool_call.function.name,
|
||||
tool_call.function.arguments.to_string(),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| std::io::Error::other(e.to_string()))?;
|
||||
println!("\nResult: {}", res);
|
||||
|
@ -111,6 +208,7 @@ pub async fn stream_to_stdout<M: StreamingCompletionModel>(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!(); // New line after streaming completes
|
||||
|
||||
Ok(())
|
||||
|
|
Loading…
Reference in New Issue