mirror of https://github.com/0xplaygrounds/rig
feat: openrouter streaming
This commit is contained in:
parent
b7c80c6d19
commit
c432a6bdcd
|
@ -11,6 +11,7 @@ use reqwest::RequestBuilder;
|
|||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use tracing::debug;
|
||||
|
||||
// ================================================================
|
||||
// OpenAI Completion Streaming API
|
||||
|
@ -26,10 +27,11 @@ pub struct StreamingFunction {
|
|||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct StreamingToolCall {
|
||||
pub index: usize,
|
||||
pub id: Option<String>,
|
||||
pub function: StreamingFunction,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct StreamingDelta {
|
||||
#[serde(default)]
|
||||
content: Option<String>,
|
||||
|
@ -37,12 +39,12 @@ struct StreamingDelta {
|
|||
tool_calls: Vec<StreamingToolCall>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct StreamingChoice {
|
||||
delta: StreamingDelta,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct StreamingCompletionChunk {
|
||||
choices: Vec<StreamingChoice>,
|
||||
usage: Option<Usage>,
|
||||
|
@ -94,7 +96,7 @@ pub async fn send_compatible_streaming_request(
|
|||
};
|
||||
|
||||
let mut partial_data = None;
|
||||
let mut calls: HashMap<usize, (String, String)> = HashMap::new();
|
||||
let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = match chunk_result {
|
||||
|
@ -139,6 +141,8 @@ pub async fn send_compatible_streaming_request(
|
|||
let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
|
||||
|
||||
let Ok(data) = data else {
|
||||
let err = data.unwrap_err();
|
||||
debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
|
||||
continue;
|
||||
};
|
||||
|
||||
|
@ -150,35 +154,39 @@ pub async fn send_compatible_streaming_request(
|
|||
if !delta.tool_calls.is_empty() {
|
||||
for tool_call in &delta.tool_calls {
|
||||
let function = tool_call.function.clone();
|
||||
|
||||
// Start of tool call
|
||||
// name: Some(String)
|
||||
// arguments: None
|
||||
if function.name.is_some() && function.arguments.is_empty() {
|
||||
calls.insert(tool_call.index, (function.name.clone().unwrap(), "".to_string()));
|
||||
let id = tool_call.id.clone().unwrap_or("".to_string());
|
||||
|
||||
calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
|
||||
}
|
||||
// Part of tool call
|
||||
// name: None
|
||||
// arguments: Some(String)
|
||||
else if function.name.is_none() && !function.arguments.is_empty() {
|
||||
let Some((name, arguments)) = calls.get(&tool_call.index) else {
|
||||
let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
|
||||
debug!("Partial tool call received but tool call was never started.");
|
||||
continue;
|
||||
};
|
||||
|
||||
let new_arguments = &tool_call.function.arguments;
|
||||
let arguments = format!("{}{}", arguments, new_arguments);
|
||||
|
||||
calls.insert(tool_call.index, (name.clone(), arguments));
|
||||
calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
|
||||
}
|
||||
// Entire tool call
|
||||
else {
|
||||
let name = function.name.unwrap();
|
||||
let id = tool_call.id.clone().unwrap_or("".to_string());
|
||||
let name = function.name.expect("function name should be present for complete tool call");
|
||||
let arguments = function.arguments;
|
||||
let Ok(arguments) = serde_json::from_str(&arguments) else {
|
||||
debug!("Couldn't serialize '{}' as a json value", arguments);
|
||||
continue;
|
||||
};
|
||||
|
||||
yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments))
|
||||
yield Ok(streaming::RawStreamingChoice::ToolCall(id, name, arguments))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -195,12 +203,14 @@ pub async fn send_compatible_streaming_request(
|
|||
}
|
||||
}
|
||||
|
||||
for (_, (name, arguments)) in calls {
|
||||
for (_, (id, name, arguments)) in calls {
|
||||
let Ok(arguments) = serde_json::from_str(&arguments) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
yield Ok(RawStreamingChoice::ToolCall(name, "".to_string(), arguments))
|
||||
println!("{id} {name}");
|
||||
|
||||
yield Ok(RawStreamingChoice::ToolCall(id, name, arguments))
|
||||
}
|
||||
|
||||
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
|
||||
|
|
|
@ -0,0 +1,125 @@
|
|||
use crate::{agent::AgentBuilder, extractor::ExtractorBuilder};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::completion::CompletionModel;
|
||||
|
||||
// ================================================================
|
||||
// Main openrouter Client
|
||||
// ================================================================
|
||||
const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Client {
|
||||
base_url: String,
|
||||
http_client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
/// Create a new OpenRouter client with the given API key.
|
||||
pub fn new(api_key: &str) -> Self {
|
||||
Self::from_url(api_key, OPENROUTER_API_BASE_URL)
|
||||
}
|
||||
|
||||
/// Create a new OpenRouter client with the given API key and base API URL.
|
||||
pub fn from_url(api_key: &str, base_url: &str) -> Self {
|
||||
Self {
|
||||
base_url: base_url.to_string(),
|
||||
http_client: reqwest::Client::builder()
|
||||
.default_headers({
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
"Authorization",
|
||||
format!("Bearer {}", api_key)
|
||||
.parse()
|
||||
.expect("Bearer token should parse"),
|
||||
);
|
||||
headers
|
||||
})
|
||||
.build()
|
||||
.expect("OpenRouter reqwest client should build"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new openrouter client from the `openrouter_API_KEY` environment variable.
|
||||
/// Panics if the environment variable is not set.
|
||||
pub fn from_env() -> Self {
|
||||
let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
|
||||
Self::new(&api_key)
|
||||
}
|
||||
|
||||
pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
|
||||
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
|
||||
self.http_client.post(url)
|
||||
}
|
||||
|
||||
/// Create a completion model with the given name.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use rig::providers::openrouter::{Client, self};
|
||||
///
|
||||
/// // Initialize the openrouter client
|
||||
/// let openrouter = Client::new("your-openrouter-api-key");
|
||||
///
|
||||
/// let llama_3_1_8b = openrouter.completion_model(openrouter::LLAMA_3_1_8B);
|
||||
/// ```
|
||||
pub fn completion_model(&self, model: &str) -> CompletionModel {
|
||||
CompletionModel::new(self.clone(), model)
|
||||
}
|
||||
|
||||
/// Create an agent builder with the given completion model.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use rig::providers::openrouter::{Client, self};
|
||||
///
|
||||
/// // Initialize the Eternal client
|
||||
/// let openrouter = Client::new("your-openrouter-api-key");
|
||||
///
|
||||
/// let agent = openrouter.agent(openrouter::LLAMA_3_1_8B)
|
||||
/// .preamble("You are comedian AI with a mission to make people laugh.")
|
||||
/// .temperature(0.0)
|
||||
/// .build();
|
||||
/// ```
|
||||
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
|
||||
AgentBuilder::new(self.completion_model(model))
|
||||
}
|
||||
|
||||
/// Create an extractor builder with the given completion model.
|
||||
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
|
||||
&self,
|
||||
model: &str,
|
||||
) -> ExtractorBuilder<T, CompletionModel> {
|
||||
ExtractorBuilder::new(self.completion_model(model))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ApiErrorResponse {
|
||||
pub(crate) message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ApiResponse<T> {
|
||||
Ok(T),
|
||||
Err(ApiErrorResponse),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: usize,
|
||||
pub completion_tokens: usize,
|
||||
pub total_tokens: usize,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Usage {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Prompt tokens: {} Total tokens: {}",
|
||||
self.prompt_tokens, self.total_tokens
|
||||
)
|
||||
}
|
||||
}
|
|
@ -1,147 +1,16 @@
|
|||
//! OpenRouter Inference API client and Rig integration
|
||||
//!
|
||||
//! # Example
|
||||
//! ```
|
||||
//! use rig::providers::openrouter;
|
||||
//!
|
||||
//! let client = openrouter::Client::new("YOUR_API_KEY");
|
||||
//!
|
||||
//! let llama_3_1_8b = client.completion_model(openrouter::LLAMA_3_1_8B);
|
||||
//! ```
|
||||
use serde::Deserialize;
|
||||
|
||||
use super::client::{ApiErrorResponse, ApiResponse, Client, Usage};
|
||||
|
||||
use crate::{
|
||||
agent::AgentBuilder,
|
||||
completion::{self, CompletionError, CompletionRequest},
|
||||
extractor::ExtractorBuilder,
|
||||
json_utils,
|
||||
providers::openai::Message,
|
||||
OneOrMany,
|
||||
};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use super::openai::AssistantContent;
|
||||
|
||||
// ================================================================
|
||||
// Main openrouter Client
|
||||
// ================================================================
|
||||
const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Client {
|
||||
base_url: String,
|
||||
http_client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
/// Create a new OpenRouter client with the given API key.
|
||||
pub fn new(api_key: &str) -> Self {
|
||||
Self::from_url(api_key, OPENROUTER_API_BASE_URL)
|
||||
}
|
||||
|
||||
/// Create a new OpenRouter client with the given API key and base API URL.
|
||||
pub fn from_url(api_key: &str, base_url: &str) -> Self {
|
||||
Self {
|
||||
base_url: base_url.to_string(),
|
||||
http_client: reqwest::Client::builder()
|
||||
.default_headers({
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
"Authorization",
|
||||
format!("Bearer {}", api_key)
|
||||
.parse()
|
||||
.expect("Bearer token should parse"),
|
||||
);
|
||||
headers
|
||||
})
|
||||
.build()
|
||||
.expect("OpenRouter reqwest client should build"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new openrouter client from the `openrouter_API_KEY` environment variable.
|
||||
/// Panics if the environment variable is not set.
|
||||
pub fn from_env() -> Self {
|
||||
let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
|
||||
Self::new(&api_key)
|
||||
}
|
||||
|
||||
fn post(&self, path: &str) -> reqwest::RequestBuilder {
|
||||
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
|
||||
self.http_client.post(url)
|
||||
}
|
||||
|
||||
/// Create a completion model with the given name.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use rig::providers::openrouter::{Client, self};
|
||||
///
|
||||
/// // Initialize the openrouter client
|
||||
/// let openrouter = Client::new("your-openrouter-api-key");
|
||||
///
|
||||
/// let llama_3_1_8b = openrouter.completion_model(openrouter::LLAMA_3_1_8B);
|
||||
/// ```
|
||||
pub fn completion_model(&self, model: &str) -> CompletionModel {
|
||||
CompletionModel::new(self.clone(), model)
|
||||
}
|
||||
|
||||
/// Create an agent builder with the given completion model.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use rig::providers::openrouter::{Client, self};
|
||||
///
|
||||
/// // Initialize the Eternal client
|
||||
/// let openrouter = Client::new("your-openrouter-api-key");
|
||||
///
|
||||
/// let agent = openrouter.agent(openrouter::LLAMA_3_1_8B)
|
||||
/// .preamble("You are comedian AI with a mission to make people laugh.")
|
||||
/// .temperature(0.0)
|
||||
/// .build();
|
||||
/// ```
|
||||
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
|
||||
AgentBuilder::new(self.completion_model(model))
|
||||
}
|
||||
|
||||
/// Create an extractor builder with the given completion model.
|
||||
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
|
||||
&self,
|
||||
model: &str,
|
||||
) -> ExtractorBuilder<T, CompletionModel> {
|
||||
ExtractorBuilder::new(self.completion_model(model))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ApiErrorResponse {
|
||||
message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum ApiResponse<T> {
|
||||
Ok(T),
|
||||
Err(ApiErrorResponse),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: usize,
|
||||
pub completion_tokens: usize,
|
||||
pub total_tokens: usize,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Usage {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Prompt tokens: {} Total tokens: {}",
|
||||
self.prompt_tokens, self.total_tokens
|
||||
)
|
||||
}
|
||||
}
|
||||
use crate::providers::openai::AssistantContent;
|
||||
|
||||
// ================================================================
|
||||
// OpenRouter Completion API
|
||||
|
@ -241,7 +110,7 @@ pub struct Choice {
|
|||
|
||||
#[derive(Clone)]
|
||||
pub struct CompletionModel {
|
||||
client: Client,
|
||||
pub(crate) client: Client,
|
||||
/// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
|
||||
pub model: String,
|
||||
}
|
||||
|
@ -253,16 +122,11 @@ impl CompletionModel {
|
|||
model: model.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl completion::CompletionModel for CompletionModel {
|
||||
type Response = CompletionResponse;
|
||||
|
||||
#[cfg_attr(feature = "worker", worker::send)]
|
||||
async fn completion(
|
||||
pub(crate) fn create_completion_request(
|
||||
&self,
|
||||
completion_request: CompletionRequest,
|
||||
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
|
||||
) -> Result<Value, CompletionError> {
|
||||
// Add preamble to chat history (if available)
|
||||
let mut full_history: Vec<Message> = match &completion_request.preamble {
|
||||
Some(preamble) => vec![Message::system(preamble)],
|
||||
|
@ -292,16 +156,30 @@ impl completion::CompletionModel for CompletionModel {
|
|||
"temperature": completion_request.temperature,
|
||||
});
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post("/chat/completions")
|
||||
.json(
|
||||
&if let Some(params) = completion_request.additional_params {
|
||||
let request = if let Some(params) = completion_request.additional_params {
|
||||
json_utils::merge(request, params)
|
||||
} else {
|
||||
request
|
||||
},
|
||||
)
|
||||
};
|
||||
|
||||
Ok(request)
|
||||
}
|
||||
}
|
||||
|
||||
impl completion::CompletionModel for CompletionModel {
|
||||
type Response = CompletionResponse;
|
||||
|
||||
#[cfg_attr(feature = "worker", worker::send)]
|
||||
async fn completion(
|
||||
&self,
|
||||
completion_request: CompletionRequest,
|
||||
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
|
||||
let request = self.create_completion_request(completion_request)?;
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post("/chat/completions")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
//! OpenRouter Inference API client and Rig integration
|
||||
//!
|
||||
//! # Example
|
||||
//! ```
|
||||
//! use rig::providers::openrouter;
|
||||
//!
|
||||
//! let client = openrouter::Client::new("YOUR_API_KEY");
|
||||
//!
|
||||
//! let llama_3_1_8b = client.completion_model(openrouter::LLAMA_3_1_8B);
|
||||
//! ```
|
||||
|
||||
pub mod client;
|
||||
pub mod completion;
|
||||
pub mod streaming;
|
||||
|
||||
pub use client::*;
|
||||
pub use completion::*;
|
|
@ -0,0 +1,313 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use crate::{
|
||||
json_utils,
|
||||
message::{ToolCall, ToolFunction},
|
||||
streaming::{self},
|
||||
};
|
||||
use async_stream::stream;
|
||||
use futures::StreamExt;
|
||||
use reqwest::RequestBuilder;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use crate::{
|
||||
completion::{CompletionError, CompletionRequest},
|
||||
streaming::StreamingCompletionModel,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct StreamingCompletionResponse {
|
||||
pub id: String,
|
||||
pub choices: Vec<StreamingChoice>,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub object: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub system_fingerprint: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub usage: Option<ResponseUsage>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct StreamingChoice {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub finish_reason: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub native_finish_reason: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub logprobs: Option<Value>,
|
||||
pub index: usize,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub message: Option<MessageResponse>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub delta: Option<DeltaResponse>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<ErrorResponse>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct MessageResponse {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub refusal: Option<Value>,
|
||||
#[serde(default)]
|
||||
pub tool_calls: Vec<OpenRouterToolCall>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct OpenRouterToolFunction {
|
||||
pub name: Option<String>,
|
||||
pub arguments: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct OpenRouterToolCall {
|
||||
pub index: usize,
|
||||
pub id: Option<String>,
|
||||
pub r#type: Option<String>,
|
||||
pub function: OpenRouterToolFunction,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
|
||||
pub struct ResponseUsage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct ErrorResponse {
|
||||
pub code: i32,
|
||||
pub message: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<HashMap<String, Value>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct DeltaResponse {
|
||||
pub role: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
#[serde(default)]
|
||||
pub tool_calls: Vec<OpenRouterToolCall>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub native_finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FinalCompletionResponse {
|
||||
pub usage: ResponseUsage,
|
||||
}
|
||||
|
||||
impl StreamingCompletionModel for super::CompletionModel {
|
||||
type StreamingResponse = FinalCompletionResponse;
|
||||
|
||||
async fn stream(
|
||||
&self,
|
||||
completion_request: CompletionRequest,
|
||||
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
|
||||
{
|
||||
let request = self.create_completion_request(completion_request)?;
|
||||
|
||||
let request = json_utils::merge(request, json!({"stream": true}));
|
||||
|
||||
let builder = self.client.post("/chat/completions").json(&request);
|
||||
|
||||
send_streaming_request(builder).await
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_streaming_request(
|
||||
request_builder: RequestBuilder,
|
||||
) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError> {
|
||||
let response = request_builder.send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(CompletionError::ProviderError(format!(
|
||||
"{}: {}",
|
||||
response.status(),
|
||||
response.text().await?
|
||||
)));
|
||||
}
|
||||
|
||||
// Handle OpenAI Compatible SSE chunks
|
||||
let stream = Box::pin(stream! {
|
||||
let mut stream = response.bytes_stream();
|
||||
let mut tool_calls = HashMap::new();
|
||||
let mut partial_line = String::new();
|
||||
let mut final_usage = None;
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = match chunk_result {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
yield Err(CompletionError::from(e));
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let text = match String::from_utf8(chunk.to_vec()) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
yield Err(CompletionError::ResponseError(e.to_string()));
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
for line in text.lines() {
|
||||
let mut line = line.to_string();
|
||||
|
||||
// Skip empty lines and processing messages, as well as [DONE] (might be useful though)
|
||||
if line.trim().is_empty() || line.trim() == ": OPENROUTER PROCESSING" || line.trim() == "data: [DONE]" {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle data: prefix
|
||||
line = line.strip_prefix("data: ").unwrap_or(&line).to_string();
|
||||
|
||||
// If line starts with { but doesn't end with }, it's a partial JSON
|
||||
if line.starts_with('{') && !line.ends_with('}') {
|
||||
partial_line = line;
|
||||
continue;
|
||||
}
|
||||
|
||||
// If we have a partial line and this line ends with }, complete it
|
||||
if !partial_line.is_empty() {
|
||||
if line.ends_with('}') {
|
||||
partial_line.push_str(&line);
|
||||
line = partial_line;
|
||||
partial_line = String::new();
|
||||
} else {
|
||||
partial_line.push_str(&line);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
let data = match serde_json::from_str::<StreamingCompletionResponse>(&line) {
|
||||
Ok(data) => data,
|
||||
Err(_) => {
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
let choice = data.choices.first().expect("Should have at least one choice");
|
||||
|
||||
// TODO this has to handle outputs like this:
|
||||
// [{"index": 0, "id": "call_DdmO9pD3xa9XTPNJ32zg2hcA", "function": {"arguments": "", "name": "get_weather"}, "type": "function"}]
|
||||
// [{"index": 0, "id": null, "function": {"arguments": "{\"", "name": null}, "type": null}]
|
||||
// [{"index": 0, "id": null, "function": {"arguments": "location", "name": null}, "type": null}]
|
||||
// [{"index": 0, "id": null, "function": {"arguments": "\":\"", "name": null}, "type": null}]
|
||||
// [{"index": 0, "id": null, "function": {"arguments": "Paris", "name": null}, "type": null}]
|
||||
// [{"index": 0, "id": null, "function": {"arguments": ",", "name": null}, "type": null}]
|
||||
// [{"index": 0, "id": null, "function": {"arguments": " France", "name": null}, "type": null}]
|
||||
// [{"index": 0, "id": null, "function": {"arguments": "\"}", "name": null}, "type": null}]
|
||||
if let Some(delta) = &choice.delta {
|
||||
if !delta.tool_calls.is_empty() {
|
||||
for tool_call in &delta.tool_calls {
|
||||
let index = tool_call.index;
|
||||
|
||||
// Get or create tool call entry
|
||||
let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
|
||||
id: String::new(),
|
||||
function: ToolFunction {
|
||||
name: String::new(),
|
||||
arguments: serde_json::Value::Null,
|
||||
},
|
||||
});
|
||||
|
||||
// Update fields if present
|
||||
if let Some(id) = &tool_call.id {
|
||||
if !id.is_empty() {
|
||||
existing_tool_call.id = id.clone();
|
||||
}
|
||||
}
|
||||
if let Some(name) = &tool_call.function.name {
|
||||
if !name.is_empty() {
|
||||
existing_tool_call.function.name = name.clone();
|
||||
}
|
||||
}
|
||||
if let Some(chunk) = &tool_call.function.arguments {
|
||||
// Convert current arguments to string if needed
|
||||
let current_args = match &existing_tool_call.function.arguments {
|
||||
serde_json::Value::Null => String::new(),
|
||||
serde_json::Value::String(s) => s.clone(),
|
||||
v => v.to_string(),
|
||||
};
|
||||
|
||||
// Concatenate the new chunk
|
||||
let combined = format!("{}{}", current_args, chunk);
|
||||
|
||||
// Try to parse as JSON if it looks complete
|
||||
if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
|
||||
match serde_json::from_str(&combined) {
|
||||
Ok(parsed) => existing_tool_call.function.arguments = parsed,
|
||||
Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
|
||||
}
|
||||
} else {
|
||||
existing_tool_call.function.arguments = serde_json::Value::String(combined);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(content) = &delta.content {
|
||||
if !content.is_empty() {
|
||||
yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(usage) = data.usage {
|
||||
final_usage = Some(usage);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle message format
|
||||
if let Some(message) = &choice.message {
|
||||
if !message.tool_calls.is_empty() {
|
||||
for tool_call in &message.tool_calls {
|
||||
let name = tool_call.function.name.clone();
|
||||
let id = tool_call.id.clone();
|
||||
let arguments = if let Some(args) = &tool_call.function.arguments {
|
||||
// Try to parse the string as JSON, fallback to string value
|
||||
match serde_json::from_str(args) {
|
||||
Ok(v) => v,
|
||||
Err(_) => serde_json::Value::String(args.to_string()),
|
||||
}
|
||||
} else {
|
||||
serde_json::Value::Null
|
||||
};
|
||||
let index = tool_call.index;
|
||||
|
||||
tool_calls.insert(index, ToolCall{
|
||||
id: id.unwrap_or_default(),
|
||||
function: ToolFunction {
|
||||
name: name.unwrap_or_default(),
|
||||
arguments,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if !message.content.is_empty() {
|
||||
yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (_, tool_call) in tool_calls.into_iter() {
|
||||
|
||||
yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call.function.name, tool_call.id, tool_call.function.arguments));
|
||||
}
|
||||
|
||||
yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
|
||||
usage: final_usage.unwrap_or_default()
|
||||
}))
|
||||
|
||||
});
|
||||
|
||||
Ok(streaming::StreamingCompletionResponse::new(stream))
|
||||
}
|
Loading…
Reference in New Issue