feat: openrouter streaming

This commit is contained in:
yavens 2025-04-15 16:15:48 -04:00
parent b7c80c6d19
commit c432a6bdcd
5 changed files with 506 additions and 163 deletions

View File

@ -11,6 +11,7 @@ use reqwest::RequestBuilder;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
use tracing::debug;
// ================================================================
// OpenAI Completion Streaming API
@ -26,10 +27,11 @@ pub struct StreamingFunction {
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingToolCall {
pub index: usize,
pub id: Option<String>,
pub function: StreamingFunction,
}
#[derive(Deserialize)]
#[derive(Deserialize, Debug)]
struct StreamingDelta {
#[serde(default)]
content: Option<String>,
@ -37,12 +39,12 @@ struct StreamingDelta {
tool_calls: Vec<StreamingToolCall>,
}
#[derive(Deserialize)]
#[derive(Deserialize, Debug)]
struct StreamingChoice {
delta: StreamingDelta,
}
#[derive(Deserialize)]
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
choices: Vec<StreamingChoice>,
usage: Option<Usage>,
@ -94,7 +96,7 @@ pub async fn send_compatible_streaming_request(
};
let mut partial_data = None;
let mut calls: HashMap<usize, (String, String)> = HashMap::new();
let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
while let Some(chunk_result) = stream.next().await {
let chunk = match chunk_result {
@ -139,6 +141,8 @@ pub async fn send_compatible_streaming_request(
let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
let Ok(data) = data else {
let err = data.unwrap_err();
debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
continue;
};
@ -150,35 +154,39 @@ pub async fn send_compatible_streaming_request(
if !delta.tool_calls.is_empty() {
for tool_call in &delta.tool_calls {
let function = tool_call.function.clone();
// Start of tool call
// name: Some(String)
// arguments: None
if function.name.is_some() && function.arguments.is_empty() {
calls.insert(tool_call.index, (function.name.clone().unwrap(), "".to_string()));
let id = tool_call.id.clone().unwrap_or("".to_string());
calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
}
// Part of tool call
// name: None
// arguments: Some(String)
else if function.name.is_none() && !function.arguments.is_empty() {
let Some((name, arguments)) = calls.get(&tool_call.index) else {
let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
debug!("Partial tool call received but tool call was never started.");
continue;
};
let new_arguments = &tool_call.function.arguments;
let arguments = format!("{}{}", arguments, new_arguments);
calls.insert(tool_call.index, (name.clone(), arguments));
calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
}
// Entire tool call
else {
let name = function.name.unwrap();
let id = tool_call.id.clone().unwrap_or("".to_string());
let name = function.name.expect("function name should be present for complete tool call");
let arguments = function.arguments;
let Ok(arguments) = serde_json::from_str(&arguments) else {
debug!("Couldn't serialize '{}' as a json value", arguments);
continue;
};
yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments))
yield Ok(streaming::RawStreamingChoice::ToolCall(id, name, arguments))
}
}
}
@ -195,12 +203,14 @@ pub async fn send_compatible_streaming_request(
}
}
for (_, (name, arguments)) in calls {
for (_, (id, name, arguments)) in calls {
let Ok(arguments) = serde_json::from_str(&arguments) else {
continue;
};
yield Ok(RawStreamingChoice::ToolCall(name, "".to_string(), arguments))
println!("{id} {name}");
yield Ok(RawStreamingChoice::ToolCall(id, name, arguments))
}
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {

View File

@ -0,0 +1,125 @@
use crate::{agent::AgentBuilder, extractor::ExtractorBuilder};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use super::completion::CompletionModel;
// ================================================================
// Main openrouter Client
// ================================================================
const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
/// OpenRouter API client.
///
/// Holds the base API URL and a `reqwest::Client` pre-configured with the
/// `Authorization: Bearer <key>` default header.
#[derive(Clone)]
pub struct Client {
    // Base URL requests are made against (defaults to OPENROUTER_API_BASE_URL).
    base_url: String,
    // HTTP client carrying the bearer-token default header.
    http_client: reqwest::Client,
}
impl Client {
    /// Create a new OpenRouter client with the given API key.
    pub fn new(api_key: &str) -> Self {
        Self::from_url(api_key, OPENROUTER_API_BASE_URL)
    }

    /// Create a new OpenRouter client with the given API key and base API URL.
    ///
    /// # Panics
    /// Panics if the bearer token cannot be parsed into a header value or if
    /// the underlying `reqwest` client fails to build.
    pub fn from_url(api_key: &str, base_url: &str) -> Self {
        Self {
            base_url: base_url.to_string(),
            http_client: reqwest::Client::builder()
                .default_headers({
                    let mut headers = reqwest::header::HeaderMap::new();
                    headers.insert(
                        "Authorization",
                        format!("Bearer {}", api_key)
                            .parse()
                            .expect("Bearer token should parse"),
                    );
                    headers
                })
                .build()
                .expect("OpenRouter reqwest client should build"),
        }
    }

    /// Create a new OpenRouter client from the `OPENROUTER_API_KEY` environment variable.
    /// Panics if the environment variable is not set.
    pub fn from_env() -> Self {
        let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
        Self::new(&api_key)
    }

    /// Build a POST request for `path` relative to the configured base URL.
    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
        // Join the two parts with exactly one slash. The previous approach,
        // `format!("{}/{}", ...).replace("//", "/")`, also rewrote the `//`
        // in the `https://` scheme separator, corrupting the URL.
        let url = format!(
            "{}/{}",
            self.base_url.trim_end_matches('/'),
            path.trim_start_matches('/')
        );
        self.http_client.post(url)
    }

    /// Create a completion model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::openrouter::{Client, self};
    ///
    /// // Initialize the openrouter client
    /// let openrouter = Client::new("your-openrouter-api-key");
    ///
    /// let llama_3_1_8b = openrouter.completion_model(openrouter::LLAMA_3_1_8B);
    /// ```
    pub fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }

    /// Create an agent builder with the given completion model.
    ///
    /// # Example
    /// ```
    /// use rig::providers::openrouter::{Client, self};
    ///
    /// // Initialize the OpenRouter client
    /// let openrouter = Client::new("your-openrouter-api-key");
    ///
    /// let agent = openrouter.agent(openrouter::LLAMA_3_1_8B)
    ///    .preamble("You are comedian AI with a mission to make people laugh.")
    ///    .temperature(0.0)
    ///    .build();
    /// ```
    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
        AgentBuilder::new(self.completion_model(model))
    }

    /// Create an extractor builder with the given completion model.
    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
        &self,
        model: &str,
    ) -> ExtractorBuilder<T, CompletionModel> {
        ExtractorBuilder::new(self.completion_model(model))
    }
}
/// Error body returned by the OpenRouter API on failed requests.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    // Human-readable error description from the provider.
    pub(crate) message: String,
}
/// Untagged wrapper: deserializes either a successful payload `T` or an
/// [`ApiErrorResponse`], whichever matches the JSON shape first.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
/// Token accounting reported by the API for a completion.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub completion_tokens: usize,
    pub total_tokens: usize,
}
impl std::fmt::Display for Usage {
    /// Compact human-readable summary. Note that `completion_tokens` is not
    /// included in the rendered text (it remains available on the struct).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Prompt tokens: {prompt} Total tokens: {total}",
            prompt = self.prompt_tokens,
            total = self.total_tokens,
        )
    }
}

View File

@ -1,147 +1,16 @@
//! OpenRouter Inference API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::openrouter;
//!
//! let client = openrouter::Client::new("YOUR_API_KEY");
//!
//! let llama_3_1_8b = client.completion_model(openrouter::LLAMA_3_1_8B);
//! ```
use serde::Deserialize;
use super::client::{ApiErrorResponse, ApiResponse, Client, Usage};
use crate::{
agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest},
extractor::ExtractorBuilder,
json_utils,
providers::openai::Message,
OneOrMany,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_json::{json, Value};
use super::openai::AssistantContent;
// ================================================================
// Main openrouter Client
// ================================================================
const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
#[derive(Clone)]
pub struct Client {
base_url: String,
http_client: reqwest::Client,
}
impl Client {
/// Create a new OpenRouter client with the given API key.
pub fn new(api_key: &str) -> Self {
Self::from_url(api_key, OPENROUTER_API_BASE_URL)
}
/// Create a new OpenRouter client with the given API key and base API URL.
pub fn from_url(api_key: &str, base_url: &str) -> Self {
Self {
base_url: base_url.to_string(),
http_client: reqwest::Client::builder()
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"Authorization",
format!("Bearer {}", api_key)
.parse()
.expect("Bearer token should parse"),
);
headers
})
.build()
.expect("OpenRouter reqwest client should build"),
}
}
/// Create a new openrouter client from the `openrouter_API_KEY` environment variable.
/// Panics if the environment variable is not set.
pub fn from_env() -> Self {
let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
Self::new(&api_key)
}
fn post(&self, path: &str) -> reqwest::RequestBuilder {
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
self.http_client.post(url)
}
/// Create a completion model with the given name.
///
/// # Example
/// ```
/// use rig::providers::openrouter::{Client, self};
///
/// // Initialize the openrouter client
/// let openrouter = Client::new("your-openrouter-api-key");
///
/// let llama_3_1_8b = openrouter.completion_model(openrouter::LLAMA_3_1_8B);
/// ```
pub fn completion_model(&self, model: &str) -> CompletionModel {
CompletionModel::new(self.clone(), model)
}
/// Create an agent builder with the given completion model.
///
/// # Example
/// ```
/// use rig::providers::openrouter::{Client, self};
///
/// // Initialize the Eternal client
/// let openrouter = Client::new("your-openrouter-api-key");
///
/// let agent = openrouter.agent(openrouter::LLAMA_3_1_8B)
/// .preamble("You are comedian AI with a mission to make people laugh.")
/// .temperature(0.0)
/// .build();
/// ```
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
AgentBuilder::new(self.completion_model(model))
}
/// Create an extractor builder with the given completion model.
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
&self,
model: &str,
) -> ExtractorBuilder<T, CompletionModel> {
ExtractorBuilder::new(self.completion_model(model))
}
}
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
message: String,
}
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
Ok(T),
Err(ApiErrorResponse),
}
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
pub prompt_tokens: usize,
pub completion_tokens: usize,
pub total_tokens: usize,
}
impl std::fmt::Display for Usage {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Prompt tokens: {} Total tokens: {}",
self.prompt_tokens, self.total_tokens
)
}
}
use crate::providers::openai::AssistantContent;
// ================================================================
// OpenRouter Completion API
@ -241,7 +110,7 @@ pub struct Choice {
#[derive(Clone)]
pub struct CompletionModel {
client: Client,
pub(crate) client: Client,
/// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
pub model: String,
}
@ -253,16 +122,11 @@ impl CompletionModel {
model: model.to_string(),
}
}
}
impl completion::CompletionModel for CompletionModel {
type Response = CompletionResponse;
#[cfg_attr(feature = "worker", worker::send)]
async fn completion(
pub(crate) fn create_completion_request(
&self,
completion_request: CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
) -> Result<Value, CompletionError> {
// Add preamble to chat history (if available)
let mut full_history: Vec<Message> = match &completion_request.preamble {
Some(preamble) => vec![Message::system(preamble)],
@ -292,16 +156,30 @@ impl completion::CompletionModel for CompletionModel {
"temperature": completion_request.temperature,
});
let response = self
.client
.post("/chat/completions")
.json(
&if let Some(params) = completion_request.additional_params {
let request = if let Some(params) = completion_request.additional_params {
json_utils::merge(request, params)
} else {
request
},
)
};
Ok(request)
}
}
impl completion::CompletionModel for CompletionModel {
type Response = CompletionResponse;
#[cfg_attr(feature = "worker", worker::send)]
async fn completion(
&self,
completion_request: CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
let request = self.create_completion_request(completion_request)?;
let response = self
.client
.post("/chat/completions")
.json(&request)
.send()
.await?;

View File

@ -0,0 +1,17 @@
//! OpenRouter Inference API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::openrouter;
//!
//! let client = openrouter::Client::new("YOUR_API_KEY");
//!
//! let llama_3_1_8b = client.completion_model(openrouter::LLAMA_3_1_8B);
//! ```
// HTTP client and shared API types (auth, errors, usage).
pub mod client;
// Non-streaming chat completion integration.
pub mod completion;
// SSE streaming chat completion integration.
pub mod streaming;
// Flatten the client and completion APIs into the provider root.
pub use client::*;
pub use completion::*;

View File

@ -0,0 +1,313 @@
use std::collections::HashMap;
use crate::{
json_utils,
message::{ToolCall, ToolFunction},
streaming::{self},
};
use async_stream::stream;
use futures::StreamExt;
use reqwest::RequestBuilder;
use serde_json::{json, Value};
use crate::{
completion::{CompletionError, CompletionRequest},
streaming::StreamingCompletionModel,
};
use serde::{Deserialize, Serialize};
/// One SSE chunk payload from the OpenRouter streaming completions endpoint.
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    // Provider-assigned completion id.
    pub id: String,
    // Streamed choices; deltas and/or full messages arrive here.
    pub choices: Vec<StreamingChoice>,
    // Creation time as reported by the provider.
    pub created: u64,
    pub model: String,
    pub object: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    // Token usage, present when the provider reports it on a chunk.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<ResponseUsage>,
}
/// A single choice within a streaming chunk. Either `delta` (incremental
/// streaming format), `message` (full-message format), or `error` is
/// expected to be populated.
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamingChoice {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub native_finish_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<Value>,
    pub index: usize,
    // Full message form (non-delta responses).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub message: Option<MessageResponse>,
    // Incremental delta form (streaming responses).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub delta: Option<DeltaResponse>,
    // Per-choice provider error, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<ErrorResponse>,
}
/// Full (non-delta) assistant message carried in a streaming choice.
#[derive(Serialize, Deserialize, Debug)]
pub struct MessageResponse {
    pub role: String,
    pub content: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refusal: Option<Value>,
    // Tool calls requested by the model; defaults to empty when absent.
    #[serde(default)]
    pub tool_calls: Vec<OpenRouterToolCall>,
}
/// Function name/arguments fragment of a streamed tool call. Both fields are
/// optional because streaming deltas deliver them piecewise (name first,
/// then argument fragments).
#[derive(Serialize, Deserialize, Debug)]
pub struct OpenRouterToolFunction {
    pub name: Option<String>,
    // Partial JSON-argument text; fragments are concatenated by the caller.
    pub arguments: Option<String>,
}
/// A (possibly partial) tool call from a streaming chunk, keyed by `index`
/// so fragments of the same call can be accumulated across chunks.
#[derive(Serialize, Deserialize, Debug)]
pub struct OpenRouterToolCall {
    pub index: usize,
    // Only present on the first fragment of a call.
    pub id: Option<String>,
    pub r#type: Option<String>,
    pub function: OpenRouterToolFunction,
}
/// Token usage totals reported on a streaming chunk.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct ResponseUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
/// Provider error attached to a streaming choice.
#[derive(Serialize, Deserialize, Debug)]
pub struct ErrorResponse {
    pub code: i32,
    pub message: String,
    // Optional provider-specific details.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<HashMap<String, Value>>,
}
/// Incremental update to a choice in the streaming format: new content text
/// and/or tool-call fragments.
#[derive(Serialize, Deserialize, Debug)]
pub struct DeltaResponse {
    pub role: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    // Tool-call fragments; defaults to empty when absent.
    #[serde(default)]
    pub tool_calls: Vec<OpenRouterToolCall>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub native_finish_reason: Option<String>,
}
/// Final aggregate yielded at the end of the stream, carrying the usage
/// totals observed across chunks.
#[derive(Clone)]
pub struct FinalCompletionResponse {
    pub usage: ResponseUsage,
}
impl StreamingCompletionModel for super::CompletionModel {
    type StreamingResponse = FinalCompletionResponse;

    /// Stream a chat completion: reuse the non-streaming request payload,
    /// flip on SSE streaming, and adapt the response chunk stream.
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
    {
        let body = json_utils::merge(
            self.create_completion_request(completion_request)?,
            json!({"stream": true}),
        );
        send_streaming_request(self.client.post("/chat/completions").json(&body)).await
    }
}
/// Send an SSE streaming completion request to OpenRouter and adapt the
/// chunk stream into rig's `RawStreamingChoice` stream.
///
/// Accumulates tool-call fragments (keyed by `index`) across chunks and
/// yields them after the stream ends, followed by a `FinalResponse` with the
/// observed usage totals.
///
/// # Errors
/// Returns `CompletionError::ProviderError` when the HTTP status is not
/// successful; transport/decoding failures are yielded as `Err` items inside
/// the stream.
pub async fn send_streaming_request(
    request_builder: RequestBuilder,
) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError> {
    let response = request_builder.send().await?;

    if !response.status().is_success() {
        return Err(CompletionError::ProviderError(format!(
            "{}: {}",
            response.status(),
            response.text().await?
        )));
    }

    // Handle OpenAI Compatible SSE chunks
    let stream = Box::pin(stream! {
        let mut stream = response.bytes_stream();
        // Tool calls are streamed piecewise; accumulate them here and emit
        // once the stream is exhausted.
        let mut tool_calls = HashMap::new();
        let mut partial_line = String::new();
        let mut final_usage = None;

        while let Some(chunk_result) = stream.next().await {
            let chunk = match chunk_result {
                Ok(c) => c,
                Err(e) => {
                    yield Err(CompletionError::from(e));
                    break;
                }
            };

            // NOTE(review): a multi-byte UTF-8 character split across two
            // network chunks would fail this conversion — confirm whether
            // chunk boundaries are guaranteed to fall on whole characters.
            let text = match String::from_utf8(chunk.to_vec()) {
                Ok(t) => t,
                Err(e) => {
                    yield Err(CompletionError::ResponseError(e.to_string()));
                    break;
                }
            };

            for line in text.lines() {
                let mut line = line.to_string();

                // Skip empty lines and processing messages, as well as [DONE] (might be useful though)
                if line.trim().is_empty() || line.trim() == ": OPENROUTER PROCESSING" || line.trim() == "data: [DONE]" {
                    continue;
                }

                // Handle data: prefix
                line = line.strip_prefix("data: ").unwrap_or(&line).to_string();

                // If line starts with { but doesn't end with }, it's a partial JSON
                if line.starts_with('{') && !line.ends_with('}') {
                    partial_line = line;
                    continue;
                }

                // If we have a partial line and this line ends with }, complete it
                if !partial_line.is_empty() {
                    if line.ends_with('}') {
                        partial_line.push_str(&line);
                        line = partial_line;
                        partial_line = String::new();
                    } else {
                        partial_line.push_str(&line);
                        continue;
                    }
                }

                let data = match serde_json::from_str::<StreamingCompletionResponse>(&line) {
                    Ok(data) => data,
                    Err(_) => {
                        continue;
                    }
                };

                // Record usage before inspecting choices: providers may send
                // usage on a chunk whose `choices` array is empty, and the
                // previous placement (inside the delta branch) dropped it.
                if let Some(usage) = data.usage {
                    final_usage = Some(usage);
                }

                // Fix: `.expect("Should have at least one choice")` panicked
                // on chunks with an empty `choices` array; skip them instead.
                let Some(choice) = data.choices.first() else {
                    continue;
                };

                // Tool-call deltas arrive fragmented, e.g.:
                // [{"index": 0, "id": "call_…", "function": {"arguments": "", "name": "get_weather"}, "type": "function"}]
                // [{"index": 0, "id": null, "function": {"arguments": "{\"", "name": null}, "type": null}]
                // [{"index": 0, "id": null, "function": {"arguments": "location", "name": null}, "type": null}]
                // ... and are reassembled by index below.
                if let Some(delta) = &choice.delta {
                    if !delta.tool_calls.is_empty() {
                        for tool_call in &delta.tool_calls {
                            let index = tool_call.index;

                            // Get or create tool call entry
                            let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
                                id: String::new(),
                                function: ToolFunction {
                                    name: String::new(),
                                    arguments: serde_json::Value::Null,
                                },
                            });

                            // Update fields if present
                            if let Some(id) = &tool_call.id {
                                if !id.is_empty() {
                                    existing_tool_call.id = id.clone();
                                }
                            }
                            if let Some(name) = &tool_call.function.name {
                                if !name.is_empty() {
                                    existing_tool_call.function.name = name.clone();
                                }
                            }
                            if let Some(chunk) = &tool_call.function.arguments {
                                // Convert current arguments to string if needed
                                let current_args = match &existing_tool_call.function.arguments {
                                    serde_json::Value::Null => String::new(),
                                    serde_json::Value::String(s) => s.clone(),
                                    v => v.to_string(),
                                };

                                // Concatenate the new chunk
                                let combined = format!("{}{}", current_args, chunk);

                                // Try to parse as JSON if it looks complete
                                if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
                                    match serde_json::from_str(&combined) {
                                        Ok(parsed) => existing_tool_call.function.arguments = parsed,
                                        Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
                                    }
                                } else {
                                    existing_tool_call.function.arguments = serde_json::Value::String(combined);
                                }
                            }
                        }
                    }

                    if let Some(content) = &delta.content {
                        if !content.is_empty() {
                            yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
                        }
                    }
                }

                // Handle message format
                if let Some(message) = &choice.message {
                    if !message.tool_calls.is_empty() {
                        for tool_call in &message.tool_calls {
                            let name = tool_call.function.name.clone();
                            let id = tool_call.id.clone();
                            let arguments = if let Some(args) = &tool_call.function.arguments {
                                // Try to parse the string as JSON, fallback to string value
                                match serde_json::from_str(args) {
                                    Ok(v) => v,
                                    Err(_) => serde_json::Value::String(args.to_string()),
                                }
                            } else {
                                serde_json::Value::Null
                            };
                            let index = tool_call.index;

                            tool_calls.insert(index, ToolCall {
                                id: id.unwrap_or_default(),
                                function: ToolFunction {
                                    name: name.unwrap_or_default(),
                                    arguments,
                                },
                            });
                        }
                    }

                    if !message.content.is_empty() {
                        yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()))
                    }
                }
            }
        }

        for (_, tool_call) in tool_calls.into_iter() {
            // NOTE(review): the argument order here is (name, id, arguments),
            // while the OpenAI-compatible streaming path yields
            // (id, name, arguments) — verify against the
            // `RawStreamingChoice::ToolCall` variant definition.
            yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call.function.name, tool_call.id, tool_call.function.arguments));
        }

        yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
            usage: final_usage.unwrap_or_default()
        }))
    });

    Ok(streaming::StreamingCompletionResponse::new(stream))
}