mirror of https://github.com/0xplaygrounds/rig
feat: start refactoring streaming api
This commit is contained in:
parent
92c91d23c3
commit
fcbe648f77
|
@ -8,7 +8,7 @@ use super::decoders::sse::from_response as sse_from_response;
|
||||||
use crate::completion::{CompletionError, CompletionRequest};
|
use crate::completion::{CompletionError, CompletionRequest};
|
||||||
use crate::json_utils::merge_inplace;
|
use crate::json_utils::merge_inplace;
|
||||||
use crate::message::MessageError;
|
use crate::message::MessageError;
|
||||||
use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult};
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
#[serde(tag = "type", rename_all = "snake_case")]
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
@ -191,12 +191,12 @@ impl StreamingCompletionModel for CompletionModel {
|
||||||
fn handle_event(
|
fn handle_event(
|
||||||
event: &StreamingEvent,
|
event: &StreamingEvent,
|
||||||
current_tool_call: &mut Option<ToolCallState>,
|
current_tool_call: &mut Option<ToolCallState>,
|
||||||
) -> Option<Result<StreamingChoice, CompletionError>> {
|
) -> Option<Result<RawStreamingChoice, CompletionError>> {
|
||||||
match event {
|
match event {
|
||||||
StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
|
StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
|
||||||
ContentDelta::TextDelta { text } => {
|
ContentDelta::TextDelta { text } => {
|
||||||
if current_tool_call.is_none() {
|
if current_tool_call.is_none() {
|
||||||
return Some(Ok(StreamingChoice::Message(text.clone())));
|
return Some(Ok(RawStreamingChoice::Message(text.clone())));
|
||||||
}
|
}
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
@ -227,7 +227,7 @@ fn handle_event(
|
||||||
&tool_call.input_json
|
&tool_call.input_json
|
||||||
};
|
};
|
||||||
match serde_json::from_str(json_str) {
|
match serde_json::from_str(json_str) {
|
||||||
Ok(json_value) => Some(Ok(StreamingChoice::ToolCall(
|
Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall(
|
||||||
tool_call.name,
|
tool_call.name,
|
||||||
tool_call.id,
|
tool_call.id,
|
||||||
json_value,
|
json_value,
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
use super::openai::{send_compatible_streaming_request, TranscriptionResponse};
|
use super::openai::{send_compatible_streaming_request, TranscriptionResponse};
|
||||||
|
|
||||||
use crate::json_utils::merge;
|
use crate::json_utils::merge;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
|
||||||
use crate::{
|
use crate::{
|
||||||
agent::AgentBuilder,
|
agent::AgentBuilder,
|
||||||
completion::{self, CompletionError, CompletionRequest},
|
completion::{self, CompletionError, CompletionRequest},
|
||||||
|
@ -570,10 +570,17 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
// Azure OpenAI Streaming API
|
// Azure OpenAI Streaming API
|
||||||
// -----------------------------------------------------
|
// -----------------------------------------------------
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
|
type Response = openai::StreamingCompletionResponse;
|
||||||
|
async fn stream(
|
||||||
|
&self,
|
||||||
|
request: CompletionRequest,
|
||||||
|
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(request)?;
|
let mut request = self.create_completion_request(request)?;
|
||||||
|
|
||||||
request = merge(request, json!({"stream": true}));
|
request = merge(
|
||||||
|
request,
|
||||||
|
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||||
|
);
|
||||||
|
|
||||||
let builder = self
|
let builder = self
|
||||||
.client
|
.client
|
||||||
|
|
|
@ -10,8 +10,9 @@
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
use crate::json_utils::merge;
|
use crate::json_utils::merge;
|
||||||
|
use crate::providers::openai;
|
||||||
use crate::providers::openai::send_compatible_streaming_request;
|
use crate::providers::openai::send_compatible_streaming_request;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
|
||||||
use crate::{
|
use crate::{
|
||||||
completion::{self, CompletionError, CompletionModel, CompletionRequest},
|
completion::{self, CompletionError, CompletionModel, CompletionRequest},
|
||||||
extractor::ExtractorBuilder,
|
extractor::ExtractorBuilder,
|
||||||
|
@ -463,13 +464,17 @@ impl CompletionModel for DeepSeekCompletionModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for DeepSeekCompletionModel {
|
impl StreamingCompletionModel for DeepSeekCompletionModel {
|
||||||
|
type Response = openai::StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
completion_request: CompletionRequest,
|
completion_request: CompletionRequest,
|
||||||
) -> Result<StreamingResult, CompletionError> {
|
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(completion_request)?;
|
let mut request = self.create_completion_request(completion_request)?;
|
||||||
|
|
||||||
request = merge(request, json!({"stream": true}));
|
request = merge(
|
||||||
|
request,
|
||||||
|
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||||
|
);
|
||||||
|
|
||||||
let builder = self.client.post("/v1/chat/completions").json(&request);
|
let builder = self.client.post("/v1/chat/completions").json(&request);
|
||||||
send_compatible_streaming_request(builder).await
|
send_compatible_streaming_request(builder).await
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
use super::openai;
|
use super::openai;
|
||||||
use crate::json_utils::merge;
|
use crate::json_utils::merge;
|
||||||
use crate::providers::openai::send_compatible_streaming_request;
|
use crate::providers::openai::send_compatible_streaming_request;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
|
||||||
use crate::{
|
use crate::{
|
||||||
agent::AgentBuilder,
|
agent::AgentBuilder,
|
||||||
completion::{self, CompletionError, CompletionRequest},
|
completion::{self, CompletionError, CompletionRequest},
|
||||||
|
@ -495,10 +495,18 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
|
type Response = openai::StreamingCompletionResponse;
|
||||||
|
|
||||||
|
async fn stream(
|
||||||
|
&self,
|
||||||
|
request: CompletionRequest,
|
||||||
|
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(request)?;
|
let mut request = self.create_completion_request(request)?;
|
||||||
|
|
||||||
request = merge(request, json!({"stream": true}));
|
request = merge(
|
||||||
|
request,
|
||||||
|
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||||
|
);
|
||||||
|
|
||||||
let builder = self.client.post("/chat/completions").json(&request);
|
let builder = self.client.post("/chat/completions").json(&request);
|
||||||
|
|
||||||
|
|
|
@ -74,9 +74,9 @@ impl StreamingCompletionModel for CompletionModel {
|
||||||
|
|
||||||
match choice.content.parts.first() {
|
match choice.content.parts.first() {
|
||||||
super::completion::gemini_api_types::Part::Text(text)
|
super::completion::gemini_api_types::Part::Text(text)
|
||||||
=> yield Ok(streaming::StreamingChoice::Message(text)),
|
=> yield Ok(streaming::RawStreamingChoice::Message(text)),
|
||||||
super::completion::gemini_api_types::Part::FunctionCall(function_call)
|
super::completion::gemini_api_types::Part::FunctionCall(function_call)
|
||||||
=> yield Ok(streaming::StreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
|
=> yield Ok(streaming::RawStreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
|
||||||
_ => panic!("Unsupported response type with streaming.")
|
_ => panic!("Unsupported response type with streaming.")
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,7 +10,8 @@
|
||||||
//! ```
|
//! ```
|
||||||
use super::openai::{send_compatible_streaming_request, CompletionResponse, TranscriptionResponse};
|
use super::openai::{send_compatible_streaming_request, CompletionResponse, TranscriptionResponse};
|
||||||
use crate::json_utils::merge;
|
use crate::json_utils::merge;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
use crate::providers::openai;
|
||||||
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
|
||||||
use crate::{
|
use crate::{
|
||||||
agent::AgentBuilder,
|
agent::AgentBuilder,
|
||||||
completion::{self, CompletionError, CompletionRequest},
|
completion::{self, CompletionError, CompletionRequest},
|
||||||
|
@ -363,10 +364,17 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
|
type Response = openai::StreamingCompletionResponse;
|
||||||
|
async fn stream(
|
||||||
|
&self,
|
||||||
|
request: CompletionRequest,
|
||||||
|
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(request)?;
|
let mut request = self.create_completion_request(request)?;
|
||||||
|
|
||||||
request = merge(request, json!({"stream": true}));
|
request = merge(
|
||||||
|
request,
|
||||||
|
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||||
|
);
|
||||||
|
|
||||||
let builder = self.client.post("/chat/completions").json(&request);
|
let builder = self.client.post("/chat/completions").json(&request);
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
use super::openai::{send_compatible_streaming_request, AssistantContent};
|
use super::openai::{send_compatible_streaming_request, AssistantContent};
|
||||||
|
|
||||||
use crate::json_utils::merge_inplace;
|
use crate::json_utils::merge_inplace;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
|
||||||
use crate::{
|
use crate::{
|
||||||
agent::AgentBuilder,
|
agent::AgentBuilder,
|
||||||
completion::{self, CompletionError, CompletionRequest},
|
completion::{self, CompletionError, CompletionRequest},
|
||||||
|
@ -390,13 +390,17 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
|
type Response = openai::StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
completion_request: CompletionRequest,
|
completion_request: CompletionRequest,
|
||||||
) -> Result<StreamingResult, CompletionError> {
|
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(completion_request)?;
|
let mut request = self.create_completion_request(completion_request)?;
|
||||||
|
|
||||||
merge_inplace(&mut request, json!({"stream": true}));
|
merge_inplace(
|
||||||
|
&mut request,
|
||||||
|
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||||
|
);
|
||||||
|
|
||||||
let builder = self.client.post("/chat/completions").json(&request);
|
let builder = self.client.post("/chat/completions").json(&request);
|
||||||
|
|
||||||
|
@ -526,8 +530,10 @@ mod image_generation {
|
||||||
// ======================================
|
// ======================================
|
||||||
// Hyperbolic Audio Generation API
|
// Hyperbolic Audio Generation API
|
||||||
// ======================================
|
// ======================================
|
||||||
|
use crate::providers::openai;
|
||||||
#[cfg(feature = "audio")]
|
#[cfg(feature = "audio")]
|
||||||
pub use audio_generation::*;
|
pub use audio_generation::*;
|
||||||
|
|
||||||
#[cfg(feature = "audio")]
|
#[cfg(feature = "audio")]
|
||||||
mod audio_generation {
|
mod audio_generation {
|
||||||
use super::{ApiResponse, Client};
|
use super::{ApiResponse, Client};
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
//! ```
|
//! ```
|
||||||
use crate::json_utils::merge;
|
use crate::json_utils::merge;
|
||||||
use crate::providers::openai::send_compatible_streaming_request;
|
use crate::providers::openai::send_compatible_streaming_request;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
|
||||||
use crate::{
|
use crate::{
|
||||||
agent::AgentBuilder,
|
agent::AgentBuilder,
|
||||||
completion::{self, CompletionError, CompletionRequest},
|
completion::{self, CompletionError, CompletionRequest},
|
||||||
|
@ -24,6 +24,7 @@ use serde_json::{json, Value};
|
||||||
use std::string::FromUtf8Error;
|
use std::string::FromUtf8Error;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tracing;
|
use tracing;
|
||||||
|
use crate::providers::openai;
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum MiraError {
|
pub enum MiraError {
|
||||||
|
@ -347,10 +348,11 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
|
type Response = openai::StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
completion_request: CompletionRequest,
|
completion_request: CompletionRequest,
|
||||||
) -> Result<StreamingResult, CompletionError> {
|
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(completion_request)?;
|
let mut request = self.create_completion_request(completion_request)?;
|
||||||
|
|
||||||
request = merge(request, json!({"stream": true}));
|
request = merge(request, json!({"stream": true}));
|
||||||
|
|
|
@ -11,7 +11,7 @@
|
||||||
|
|
||||||
use crate::json_utils::merge;
|
use crate::json_utils::merge;
|
||||||
use crate::providers::openai::send_compatible_streaming_request;
|
use crate::providers::openai::send_compatible_streaming_request;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
|
||||||
use crate::{
|
use crate::{
|
||||||
agent::AgentBuilder,
|
agent::AgentBuilder,
|
||||||
completion::{self, CompletionError, CompletionRequest},
|
completion::{self, CompletionError, CompletionRequest},
|
||||||
|
@ -228,10 +228,18 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
|
type Response = openai::StreamingCompletionResponse;
|
||||||
|
|
||||||
|
async fn stream(
|
||||||
|
&self,
|
||||||
|
request: CompletionRequest,
|
||||||
|
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(request)?;
|
let mut request = self.create_completion_request(request)?;
|
||||||
|
|
||||||
request = merge(request, json!({"stream": true}));
|
request = merge(
|
||||||
|
request,
|
||||||
|
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||||
|
);
|
||||||
|
|
||||||
let builder = self.client.post("/chat/completions").json(&request);
|
let builder = self.client.post("/chat/completions").json(&request);
|
||||||
|
|
||||||
|
|
|
@ -39,7 +39,7 @@
|
||||||
//! let extractor = client.extractor::<serde_json::Value>("llama3.2");
|
//! let extractor = client.extractor::<serde_json::Value>("llama3.2");
|
||||||
//! ```
|
//! ```
|
||||||
use crate::json_utils::merge_inplace;
|
use crate::json_utils::merge_inplace;
|
||||||
use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult};
|
||||||
use crate::{
|
use crate::{
|
||||||
agent::AgentBuilder,
|
agent::AgentBuilder,
|
||||||
completion::{self, CompletionError, CompletionRequest},
|
completion::{self, CompletionError, CompletionRequest},
|
||||||
|
@ -47,7 +47,7 @@ use crate::{
|
||||||
extractor::ExtractorBuilder,
|
extractor::ExtractorBuilder,
|
||||||
json_utils, message,
|
json_utils, message,
|
||||||
message::{ImageDetail, Text},
|
message::{ImageDetail, Text},
|
||||||
Embed, OneOrMany,
|
streaming, Embed, OneOrMany,
|
||||||
};
|
};
|
||||||
use async_stream::stream;
|
use async_stream::stream;
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
|
@ -405,8 +405,30 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct StreamingCompletionResponse {
|
||||||
|
#[serde(default)]
|
||||||
|
pub done_reason: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub total_duration: Option<u64>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub load_duration: Option<u64>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub prompt_eval_count: Option<u64>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub prompt_eval_duration: Option<u64>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub eval_count: Option<u64>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub eval_duration: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
|
type Response = StreamingCompletionResponse;
|
||||||
|
|
||||||
|
async fn stream(
|
||||||
|
&self,
|
||||||
|
request: CompletionRequest,
|
||||||
|
) -> Result<streaming::StreamingCompletionResponse<Self::Response>, CompletionError> {
|
||||||
let mut request_payload = self.create_completion_request(request)?;
|
let mut request_payload = self.create_completion_request(request)?;
|
||||||
merge_inplace(&mut request_payload, json!({"stream": true}));
|
merge_inplace(&mut request_payload, json!({"stream": true}));
|
||||||
|
|
||||||
|
@ -426,7 +448,8 @@ impl StreamingCompletionModel for CompletionModel {
|
||||||
return Err(CompletionError::ProviderError(err_text));
|
return Err(CompletionError::ProviderError(err_text));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Box::pin(stream! {
|
let mut
|
||||||
|
let stream = Box::pin(stream! {
|
||||||
let mut stream = response.bytes_stream();
|
let mut stream = response.bytes_stream();
|
||||||
while let Some(chunk_result) = stream.next().await {
|
while let Some(chunk_result) = stream.next().await {
|
||||||
let chunk = match chunk_result {
|
let chunk = match chunk_result {
|
||||||
|
@ -456,13 +479,13 @@ impl StreamingCompletionModel for CompletionModel {
|
||||||
match response.message {
|
match response.message {
|
||||||
Message::Assistant{ content, tool_calls, .. } => {
|
Message::Assistant{ content, tool_calls, .. } => {
|
||||||
if !content.is_empty() {
|
if !content.is_empty() {
|
||||||
yield Ok(StreamingChoice::Message(content))
|
yield Ok(RawStreamingChoice::Message(content))
|
||||||
}
|
}
|
||||||
|
|
||||||
for tool_call in tool_calls.iter() {
|
for tool_call in tool_calls.iter() {
|
||||||
let function = tool_call.function.clone();
|
let function = tool_call.function.clone();
|
||||||
|
|
||||||
yield Ok(StreamingChoice::ToolCall(function.name, "".to_string(), function.arguments));
|
yield Ok(RawStreamingChoice::ToolCall(function.name, "".to_string(), function.arguments));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
|
@ -471,7 +494,7 @@ impl StreamingCompletionModel for CompletionModel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}))
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ use super::completion::CompletionModel;
|
||||||
use crate::completion::{CompletionError, CompletionRequest};
|
use crate::completion::{CompletionError, CompletionRequest};
|
||||||
use crate::json_utils;
|
use crate::json_utils;
|
||||||
use crate::json_utils::merge;
|
use crate::json_utils::merge;
|
||||||
|
use crate::providers::openai::Usage;
|
||||||
use crate::streaming;
|
use crate::streaming;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
||||||
use async_stream::stream;
|
use async_stream::stream;
|
||||||
|
@ -10,6 +11,7 @@ use reqwest::RequestBuilder;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use tokio::stream;
|
||||||
|
|
||||||
// ================================================================
|
// ================================================================
|
||||||
// OpenAI Completion Streaming API
|
// OpenAI Completion Streaming API
|
||||||
|
@ -42,15 +44,21 @@ struct StreamingChoice {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct StreamingCompletionResponse {
|
struct StreamingCompletionChunk {
|
||||||
choices: Vec<StreamingChoice>,
|
choices: Vec<StreamingChoice>,
|
||||||
|
usage: Option<Usage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct StreamingCompletionResponse {
|
||||||
|
usage: Option<Usage>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
|
type Response = StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
completion_request: CompletionRequest,
|
completion_request: CompletionRequest,
|
||||||
) -> Result<StreamingResult, CompletionError> {
|
) -> Result<streaming::StreamingCompletionResponse<Self::Response>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(completion_request)?;
|
let mut request = self.create_completion_request(completion_request)?;
|
||||||
request = merge(request, json!({"stream": true}));
|
request = merge(request, json!({"stream": true}));
|
||||||
|
|
||||||
|
@ -61,7 +69,7 @@ impl StreamingCompletionModel for CompletionModel {
|
||||||
|
|
||||||
pub async fn send_compatible_streaming_request(
|
pub async fn send_compatible_streaming_request(
|
||||||
request_builder: RequestBuilder,
|
request_builder: RequestBuilder,
|
||||||
) -> Result<StreamingResult, CompletionError> {
|
) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError> {
|
||||||
let response = request_builder.send().await?;
|
let response = request_builder.send().await?;
|
||||||
|
|
||||||
if !response.status().is_success() {
|
if !response.status().is_success() {
|
||||||
|
@ -73,7 +81,7 @@ pub async fn send_compatible_streaming_request(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle OpenAI Compatible SSE chunks
|
// Handle OpenAI Compatible SSE chunks
|
||||||
Ok(Box::pin(stream! {
|
let inner = Box::pin(stream! {
|
||||||
let mut stream = response.bytes_stream();
|
let mut stream = response.bytes_stream();
|
||||||
|
|
||||||
let mut partial_data = None;
|
let mut partial_data = None;
|
||||||
|
@ -121,7 +129,7 @@ pub async fn send_compatible_streaming_request(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let data = serde_json::from_str::<StreamingCompletionResponse>(&line);
|
let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
|
||||||
|
|
||||||
let Ok(data) = data else {
|
let Ok(data) = data else {
|
||||||
continue;
|
continue;
|
||||||
|
@ -162,13 +170,17 @@ pub async fn send_compatible_streaming_request(
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
|
|
||||||
yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
|
yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(content) = &choice.delta.content {
|
if let Some(content) = &choice.delta.content {
|
||||||
yield Ok(streaming::StreamingChoice::Message(content.clone()))
|
yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if &data.usage.is_some() {
|
||||||
|
usage = data.usage;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -178,7 +190,11 @@ pub async fn send_compatible_streaming_request(
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
|
|
||||||
yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
|
yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments))
|
||||||
}
|
}
|
||||||
}))
|
});
|
||||||
|
|
||||||
|
Ok(streaming::StreamingCompletionResponse::new(
|
||||||
|
inner,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,14 +13,28 @@ use crate::agent::Agent;
|
||||||
use crate::completion::{
|
use crate::completion::{
|
||||||
CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, Message,
|
CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, Message,
|
||||||
};
|
};
|
||||||
|
use crate::message::AssistantContent;
|
||||||
use futures::{Stream, StreamExt};
|
use futures::{Stream, StreamExt};
|
||||||
use std::boxed::Box;
|
use std::boxed::Box;
|
||||||
use std::fmt::{Display, Formatter};
|
use std::fmt::{Display, Formatter};
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
|
||||||
|
/// Enum representing a streaming chunk from the model
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum RawStreamingChoice<R> {
|
||||||
|
/// A text chunk from a message response
|
||||||
|
Message(String),
|
||||||
|
|
||||||
|
/// A tool call response chunk
|
||||||
|
ToolCall(String, String, serde_json::Value),
|
||||||
|
|
||||||
|
/// The final response object
|
||||||
|
FinalResponse(R),
|
||||||
|
}
|
||||||
|
|
||||||
/// Enum representing a streaming chunk from the model
|
/// Enum representing a streaming chunk from the model
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum StreamingChoice {
|
pub enum StreamingChoice {
|
||||||
/// A text chunk from a message response
|
/// A text chunk from a message response
|
||||||
Message(String),
|
Message(String),
|
||||||
|
@ -41,29 +55,87 @@ impl Display for StreamingChoice {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(target_arch = "wasm32"))]
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
pub type StreamingResult =
|
pub type StreamingResult<R> =
|
||||||
Pin<Box<dyn Stream<Item = Result<StreamingChoice, CompletionError>> + Send>>;
|
Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;
|
||||||
|
|
||||||
#[cfg(target_arch = "wasm32")]
|
#[cfg(target_arch = "wasm32")]
|
||||||
pub type StreamingResult = Pin<Box<dyn Stream<Item = Result<StreamingChoice, CompletionError>>>>;
|
pub type StreamingResult = Pin<Box<dyn Stream<Item = Result<RawStreamingChoice, CompletionError>>>>;
|
||||||
|
|
||||||
|
pub struct StreamingCompletionResponse<R> {
|
||||||
|
inner: StreamingResult<R>,
|
||||||
|
text: String,
|
||||||
|
tool_calls: Vec<(String, String, serde_json::Value)>,
|
||||||
|
pub message: Message,
|
||||||
|
pub response: Option<R>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R> StreamingCompletionResponse<R> {
|
||||||
|
pub fn new(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
|
||||||
|
Self {
|
||||||
|
inner,
|
||||||
|
text: "".to_string(),
|
||||||
|
tool_calls: vec![],
|
||||||
|
message: Message::assistant(""),
|
||||||
|
response: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R> Stream for StreamingCompletionResponse<R> {
|
||||||
|
type Item = Result<StreamingChoice, CompletionError>;
|
||||||
|
|
||||||
|
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
|
match self.inner.poll_next(cx) {
|
||||||
|
Poll::Pending => Poll::Pending,
|
||||||
|
|
||||||
|
Poll::Ready(None) => {
|
||||||
|
let content = vec![AssistantContent::text(self.text.clone())];
|
||||||
|
|
||||||
|
self.tool_calls
|
||||||
|
.iter()
|
||||||
|
.for_each(|(n, d, a)| AssistantContent::tool_call(n, derive!(), a));
|
||||||
|
|
||||||
|
self.message = Message::Assistant {
|
||||||
|
content: content.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
|
||||||
|
Poll::Ready(Some(Ok(choice))) => match choice {
|
||||||
|
RawStreamingChoice::Message(text) => {
|
||||||
|
self.text = format!("{}{}", self.text, text);
|
||||||
|
Poll::Ready(Some(Ok(choice.clone())))
|
||||||
|
}
|
||||||
|
RawStreamingChoice::ToolCall(name, description, args) => {
|
||||||
|
self.tool_calls
|
||||||
|
.push((name, description, args));
|
||||||
|
Poll::Ready(Some(Ok(choice.clone())))
|
||||||
|
}
|
||||||
|
RawStreamingChoice::FinalResponse(response) => {
|
||||||
|
self.response = Some(response);
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Trait for high-level streaming prompt interface
|
/// Trait for high-level streaming prompt interface
|
||||||
pub trait StreamingPrompt: Send + Sync {
|
pub trait StreamingPrompt<R>: Send + Sync {
|
||||||
/// Stream a simple prompt to the model
|
/// Stream a simple prompt to the model
|
||||||
fn stream_prompt(
|
fn stream_prompt(
|
||||||
&self,
|
&self,
|
||||||
prompt: &str,
|
prompt: &str,
|
||||||
) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
|
) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trait for high-level streaming chat interface
|
/// Trait for high-level streaming chat interface
|
||||||
pub trait StreamingChat: Send + Sync {
|
pub trait StreamingChat<R>: Send + Sync {
|
||||||
/// Stream a chat with history to the model
|
/// Stream a chat with history to the model
|
||||||
fn stream_chat(
|
fn stream_chat(
|
||||||
&self,
|
&self,
|
||||||
prompt: &str,
|
prompt: &str,
|
||||||
chat_history: Vec<Message>,
|
chat_history: Vec<Message>,
|
||||||
) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
|
) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trait for low-level streaming completion interface
|
/// Trait for low-level streaming completion interface
|
||||||
|
@ -78,17 +150,18 @@ pub trait StreamingCompletion<M: StreamingCompletionModel> {
|
||||||
|
|
||||||
/// Trait defining a streaming completion model
|
/// Trait defining a streaming completion model
|
||||||
pub trait StreamingCompletionModel: CompletionModel {
|
pub trait StreamingCompletionModel: CompletionModel {
|
||||||
|
type Response;
|
||||||
/// Stream a completion response for the given request
|
/// Stream a completion response for the given request
|
||||||
fn stream(
|
fn stream(
|
||||||
&self,
|
&self,
|
||||||
request: CompletionRequest,
|
request: CompletionRequest,
|
||||||
) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
|
) -> impl Future<Output = Result<StreamingCompletionResponse<Self::Response>, CompletionError>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// helper function to stream a completion request to stdout
|
/// helper function to stream a completion request to stdout
|
||||||
pub async fn stream_to_stdout<M: StreamingCompletionModel>(
|
pub async fn stream_to_stdout<M: StreamingCompletionModel, R>(
|
||||||
agent: Agent<M>,
|
agent: Agent<M>,
|
||||||
stream: &mut StreamingResult,
|
stream: &mut StreamingCompletionResponse<R>,
|
||||||
) -> Result<(), std::io::Error> {
|
) -> Result<(), std::io::Error> {
|
||||||
print!("Response: ");
|
print!("Response: ");
|
||||||
while let Some(chunk) = stream.next().await {
|
while let Some(chunk) = stream.next().await {
|
||||||
|
@ -111,6 +184,7 @@ pub async fn stream_to_stdout<M: StreamingCompletionModel>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
println!(); // New line after streaming completes
|
println!(); // New line after streaming completes
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
Loading…
Reference in New Issue