fix: compiles + formatted

This commit is contained in:
yavens 2025-04-09 12:54:55 -04:00
parent fcbe648f77
commit 86b84c82fb
20 changed files with 175 additions and 103 deletions

View File

@ -16,6 +16,10 @@ async fn main() -> Result<(), anyhow::Error> {
.await?; .await?;
stream_to_stdout(agent, &mut stream).await?; stream_to_stdout(agent, &mut stream).await?;
if let Some(response) = stream.response {
println!("Usage: {:?}", response.usage)
};
Ok(()) Ok(())
} }

View File

@ -110,23 +110,20 @@ use std::collections::HashMap;
use futures::{stream, StreamExt, TryStreamExt}; use futures::{stream, StreamExt, TryStreamExt};
use crate::streaming::StreamingCompletionResponse;
#[cfg(feature = "mcp")]
use crate::tool::McpTool;
use crate::{ use crate::{
completion::{ completion::{
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document, Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
Message, Prompt, PromptError, Message, Prompt, PromptError,
}, },
message::AssistantContent, message::AssistantContent,
streaming::{ streaming::{StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt},
StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
StreamingResult,
},
tool::{Tool, ToolSet}, tool::{Tool, ToolSet},
vector_store::{VectorStoreError, VectorStoreIndexDyn}, vector_store::{VectorStoreError, VectorStoreIndexDyn},
}; };
#[cfg(feature = "mcp")]
use crate::tool::McpTool;
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble /// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools. /// (i.e.: system prompt) and a static set of context documents and tools.
/// All context documents and tools are always provided to the agent when prompted. /// All context documents and tools are always provided to the agent when prompted.
@ -500,18 +497,21 @@ impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
} }
} }
impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> { impl<M: StreamingCompletionModel> StreamingPrompt<M::StreamingResponse> for Agent<M> {
async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> { async fn stream_prompt(
&self,
prompt: &str,
) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
self.stream_chat(prompt, vec![]).await self.stream_chat(prompt, vec![]).await
} }
} }
impl<M: StreamingCompletionModel> StreamingChat for Agent<M> { impl<M: StreamingCompletionModel> StreamingChat<M::StreamingResponse> for Agent<M> {
async fn stream_chat( async fn stream_chat(
&self, &self,
prompt: &str, prompt: &str,
chat_history: Vec<Message>, chat_history: Vec<Message>,
) -> Result<StreamingResult, CompletionError> { ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
self.stream_completion(prompt, chat_history) self.stream_completion(prompt, chat_history)
.await? .await?
.stream() .stream()

View File

@ -67,7 +67,7 @@ use std::collections::HashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use thiserror::Error; use thiserror::Error;
use crate::streaming::{StreamingCompletionModel, StreamingResult}; use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use crate::OneOrMany; use crate::OneOrMany;
use crate::{ use crate::{
json_utils, json_utils,
@ -467,7 +467,9 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
impl<M: StreamingCompletionModel> CompletionRequestBuilder<M> { impl<M: StreamingCompletionModel> CompletionRequestBuilder<M> {
/// Stream the completion request /// Stream the completion request
pub async fn stream(self) -> Result<StreamingResult, CompletionError> { pub async fn stream(
self,
) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
let model = self.model.clone(); let model = self.model.clone();
model.stream(self.build()).await model.stream(self.build()).await
} }

View File

@ -8,6 +8,7 @@ use super::decoders::sse::from_response as sse_from_response;
use crate::completion::{CompletionError, CompletionRequest}; use crate::completion::{CompletionError, CompletionRequest};
use crate::json_utils::merge_inplace; use crate::json_utils::merge_inplace;
use crate::message::MessageError; use crate::message::MessageError;
use crate::streaming;
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult}; use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult};
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -61,7 +62,7 @@ pub struct MessageDelta {
pub stop_sequence: Option<String>, pub stop_sequence: Option<String>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize, Clone)]
pub struct PartialUsage { pub struct PartialUsage {
pub output_tokens: usize, pub output_tokens: usize,
#[serde(default)] #[serde(default)]
@ -75,11 +76,18 @@ struct ToolCallState {
input_json: String, input_json: String,
} }
#[derive(Clone)]
pub struct StreamingCompletionResponse {
pub usage: PartialUsage,
}
impl StreamingCompletionModel for CompletionModel { impl StreamingCompletionModel for CompletionModel {
type StreamingResponse = StreamingCompletionResponse;
async fn stream( async fn stream(
&self, &self,
completion_request: CompletionRequest, completion_request: CompletionRequest,
) -> Result<StreamingResult, CompletionError> { ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
{
let max_tokens = if let Some(tokens) = completion_request.max_tokens { let max_tokens = if let Some(tokens) = completion_request.max_tokens {
tokens tokens
} else if let Some(tokens) = self.default_max_tokens { } else if let Some(tokens) = self.default_max_tokens {
@ -155,9 +163,10 @@ impl StreamingCompletionModel for CompletionModel {
// Use our SSE decoder to directly handle Server-Sent Events format // Use our SSE decoder to directly handle Server-Sent Events format
let sse_stream = sse_from_response(response); let sse_stream = sse_from_response(response);
Ok(Box::pin(stream! { let stream: StreamingResult<Self::StreamingResponse> = Box::pin(stream! {
let mut current_tool_call: Option<ToolCallState> = None; let mut current_tool_call: Option<ToolCallState> = None;
let mut sse_stream = Box::pin(sse_stream); let mut sse_stream = Box::pin(sse_stream);
let mut input_tokens = 0;
while let Some(sse_result) = sse_stream.next().await { while let Some(sse_result) = sse_stream.next().await {
match sse_result { match sse_result {
@ -165,6 +174,24 @@ impl StreamingCompletionModel for CompletionModel {
// Parse the SSE data as a StreamingEvent // Parse the SSE data as a StreamingEvent
match serde_json::from_str::<StreamingEvent>(&sse.data) { match serde_json::from_str::<StreamingEvent>(&sse.data) {
Ok(event) => { Ok(event) => {
match &event {
StreamingEvent::MessageStart { message } => {
input_tokens = message.usage.input_tokens;
},
StreamingEvent::MessageDelta { delta, usage } => {
if delta.stop_reason.is_some() {
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
usage: PartialUsage {
output_tokens: usage.output_tokens,
input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
}
}))
}
}
_ => {}
}
if let Some(result) = handle_event(&event, &mut current_tool_call) { if let Some(result) = handle_event(&event, &mut current_tool_call) {
yield result; yield result;
} }
@ -184,14 +211,16 @@ impl StreamingCompletionModel for CompletionModel {
} }
} }
} }
})) });
Ok(streaming::StreamingCompletionResponse::new(stream))
} }
} }
fn handle_event( fn handle_event(
event: &StreamingEvent, event: &StreamingEvent,
current_tool_call: &mut Option<ToolCallState>, current_tool_call: &mut Option<ToolCallState>,
) -> Option<Result<RawStreamingChoice, CompletionError>> { ) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
match event { match event {
StreamingEvent::ContentBlockDelta { delta, .. } => match delta { StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
ContentDelta::TextDelta { text } => { ContentDelta::TextDelta { text } => {

View File

@ -12,7 +12,7 @@
use super::openai::{send_compatible_streaming_request, TranscriptionResponse}; use super::openai::{send_compatible_streaming_request, TranscriptionResponse};
use crate::json_utils::merge; use crate::json_utils::merge;
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult}; use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use crate::{ use crate::{
agent::AgentBuilder, agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest}, completion::{self, CompletionError, CompletionRequest},
@ -570,11 +570,11 @@ impl completion::CompletionModel for CompletionModel {
// Azure OpenAI Streaming API // Azure OpenAI Streaming API
// ----------------------------------------------------- // -----------------------------------------------------
impl StreamingCompletionModel for CompletionModel { impl StreamingCompletionModel for CompletionModel {
type Response = openai::StreamingCompletionResponse; type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream( async fn stream(
&self, &self,
request: CompletionRequest, request: CompletionRequest,
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> { ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(request)?; let mut request = self.create_completion_request(request)?;
request = merge( request = merge(

View File

@ -12,7 +12,7 @@
use crate::json_utils::merge; use crate::json_utils::merge;
use crate::providers::openai; use crate::providers::openai;
use crate::providers::openai::send_compatible_streaming_request; use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult}; use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use crate::{ use crate::{
completion::{self, CompletionError, CompletionModel, CompletionRequest}, completion::{self, CompletionError, CompletionModel, CompletionRequest},
extractor::ExtractorBuilder, extractor::ExtractorBuilder,
@ -464,11 +464,11 @@ impl CompletionModel for DeepSeekCompletionModel {
} }
impl StreamingCompletionModel for DeepSeekCompletionModel { impl StreamingCompletionModel for DeepSeekCompletionModel {
type Response = openai::StreamingCompletionResponse; type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream( async fn stream(
&self, &self,
completion_request: CompletionRequest, completion_request: CompletionRequest,
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> { ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(completion_request)?; let mut request = self.create_completion_request(completion_request)?;
request = merge( request = merge(

View File

@ -13,7 +13,7 @@
use super::openai; use super::openai;
use crate::json_utils::merge; use crate::json_utils::merge;
use crate::providers::openai::send_compatible_streaming_request; use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult}; use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use crate::{ use crate::{
agent::AgentBuilder, agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest}, completion::{self, CompletionError, CompletionRequest},
@ -495,12 +495,12 @@ impl completion::CompletionModel for CompletionModel {
} }
impl StreamingCompletionModel for CompletionModel { impl StreamingCompletionModel for CompletionModel {
type Response = openai::StreamingCompletionResponse; type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream( async fn stream(
&self, &self,
request: CompletionRequest, request: CompletionRequest,
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> { ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(request)?; let mut request = self.create_completion_request(request)?;
request = merge( request = merge(

View File

@ -609,7 +609,7 @@ pub mod gemini_api_types {
HarmCategoryCivicIntegrity, HarmCategoryCivicIntegrity,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize, Clone)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct UsageMetadata { pub struct UsageMetadata {
pub prompt_token_count: i32, pub prompt_token_count: i32,

View File

@ -2,26 +2,34 @@ use async_stream::stream;
use futures::StreamExt; use futures::StreamExt;
use serde::Deserialize; use serde::Deserialize;
use super::completion::{create_request_body, gemini_api_types::ContentCandidate, CompletionModel};
use crate::providers::gemini::completion::gemini_api_types::UsageMetadata;
use crate::{ use crate::{
completion::{CompletionError, CompletionRequest}, completion::{CompletionError, CompletionRequest},
streaming::{self, StreamingCompletionModel, StreamingResult}, streaming::{self, StreamingCompletionModel},
}; };
use super::completion::{create_request_body, gemini_api_types::ContentCandidate, CompletionModel};
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct StreamGenerateContentResponse { pub struct StreamGenerateContentResponse {
/// Candidate responses from the model. /// Candidate responses from the model.
pub candidates: Vec<ContentCandidate>, pub candidates: Vec<ContentCandidate>,
pub model_version: Option<String>, pub model_version: Option<String>,
pub usage_metadata: UsageMetadata,
}
#[derive(Clone)]
pub struct StreamingCompletionResponse {
pub usage_metadata: UsageMetadata,
} }
impl StreamingCompletionModel for CompletionModel { impl StreamingCompletionModel for CompletionModel {
type StreamingResponse = StreamingCompletionResponse;
async fn stream( async fn stream(
&self, &self,
completion_request: CompletionRequest, completion_request: CompletionRequest,
) -> Result<StreamingResult, CompletionError> { ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
{
let request = create_request_body(completion_request)?; let request = create_request_body(completion_request)?;
let response = self let response = self
@ -42,7 +50,7 @@ impl StreamingCompletionModel for CompletionModel {
))); )));
} }
Ok(Box::pin(stream! { let stream = Box::pin(stream! {
let mut stream = response.bytes_stream(); let mut stream = response.bytes_stream();
while let Some(chunk_result) = stream.next().await { while let Some(chunk_result) = stream.next().await {
@ -79,8 +87,16 @@ impl StreamingCompletionModel for CompletionModel {
=> yield Ok(streaming::RawStreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)), => yield Ok(streaming::RawStreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
_ => panic!("Unsupported response type with streaming.") _ => panic!("Unsupported response type with streaming.")
}; };
if choice.finish_reason.is_some() {
yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
usage_metadata: data.usage_metadata,
}))
}
} }
} }
})) });
Ok(streaming::StreamingCompletionResponse::new(stream))
} }
} }

View File

@ -11,7 +11,7 @@
use super::openai::{send_compatible_streaming_request, CompletionResponse, TranscriptionResponse}; use super::openai::{send_compatible_streaming_request, CompletionResponse, TranscriptionResponse};
use crate::json_utils::merge; use crate::json_utils::merge;
use crate::providers::openai; use crate::providers::openai;
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult}; use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use crate::{ use crate::{
agent::AgentBuilder, agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest}, completion::{self, CompletionError, CompletionRequest},
@ -364,11 +364,11 @@ impl completion::CompletionModel for CompletionModel {
} }
impl StreamingCompletionModel for CompletionModel { impl StreamingCompletionModel for CompletionModel {
type Response = openai::StreamingCompletionResponse; type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream( async fn stream(
&self, &self,
request: CompletionRequest, request: CompletionRequest,
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> { ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(request)?; let mut request = self.create_completion_request(request)?;
request = merge( request = merge(

View File

@ -1,9 +1,9 @@
use super::completion::CompletionModel; use super::completion::CompletionModel;
use crate::completion::{CompletionError, CompletionRequest}; use crate::completion::{CompletionError, CompletionRequest};
use crate::json_utils;
use crate::json_utils::merge_inplace; use crate::json_utils::merge_inplace;
use crate::providers::openai::send_compatible_streaming_request; use crate::providers::openai::{send_compatible_streaming_request, StreamingCompletionResponse};
use crate::streaming::{StreamingCompletionModel, StreamingResult}; use crate::streaming::StreamingCompletionModel;
use crate::{json_utils, streaming};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{json, Value}; use serde_json::{json, Value};
use std::convert::Infallible; use std::convert::Infallible;
@ -55,14 +55,19 @@ struct CompletionChunk {
} }
impl StreamingCompletionModel for CompletionModel { impl StreamingCompletionModel for CompletionModel {
type StreamingResponse = StreamingCompletionResponse;
async fn stream( async fn stream(
&self, &self,
completion_request: CompletionRequest, completion_request: CompletionRequest,
) -> Result<StreamingResult, CompletionError> { ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
{
let mut request = self.create_request_body(&completion_request)?; let mut request = self.create_request_body(&completion_request)?;
// Enable streaming // Enable streaming
merge_inplace(&mut request, json!({"stream": true})); merge_inplace(
&mut request,
json!({"stream": true, "stream_options": {"include_usage": true}}),
);
if let Some(ref params) = completion_request.additional_params { if let Some(ref params) = completion_request.additional_params {
merge_inplace(&mut request, params.clone()); merge_inplace(&mut request, params.clone());

View File

@ -12,7 +12,7 @@
use super::openai::{send_compatible_streaming_request, AssistantContent}; use super::openai::{send_compatible_streaming_request, AssistantContent};
use crate::json_utils::merge_inplace; use crate::json_utils::merge_inplace;
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult}; use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use crate::{ use crate::{
agent::AgentBuilder, agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest}, completion::{self, CompletionError, CompletionRequest},
@ -390,11 +390,11 @@ impl completion::CompletionModel for CompletionModel {
} }
impl StreamingCompletionModel for CompletionModel { impl StreamingCompletionModel for CompletionModel {
type Response = openai::StreamingCompletionResponse; type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream( async fn stream(
&self, &self,
completion_request: CompletionRequest, completion_request: CompletionRequest,
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> { ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(completion_request)?; let mut request = self.create_completion_request(completion_request)?;
merge_inplace( merge_inplace(

View File

@ -8,8 +8,9 @@
//! //!
//! ``` //! ```
use crate::json_utils::merge; use crate::json_utils::merge;
use crate::providers::openai;
use crate::providers::openai::send_compatible_streaming_request; use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult}; use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use crate::{ use crate::{
agent::AgentBuilder, agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest}, completion::{self, CompletionError, CompletionRequest},
@ -24,7 +25,6 @@ use serde_json::{json, Value};
use std::string::FromUtf8Error; use std::string::FromUtf8Error;
use thiserror::Error; use thiserror::Error;
use tracing; use tracing;
use crate::providers::openai;
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum MiraError { pub enum MiraError {
@ -348,11 +348,11 @@ impl completion::CompletionModel for CompletionModel {
} }
impl StreamingCompletionModel for CompletionModel { impl StreamingCompletionModel for CompletionModel {
type Response = openai::StreamingCompletionResponse; type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream( async fn stream(
&self, &self,
completion_request: CompletionRequest, completion_request: CompletionRequest,
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> { ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(completion_request)?; let mut request = self.create_completion_request(completion_request)?;
request = merge(request, json!({"stream": true})); request = merge(request, json!({"stream": true}));

View File

@ -11,7 +11,7 @@
use crate::json_utils::merge; use crate::json_utils::merge;
use crate::providers::openai::send_compatible_streaming_request; use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult}; use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use crate::{ use crate::{
agent::AgentBuilder, agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest}, completion::{self, CompletionError, CompletionRequest},
@ -228,12 +228,12 @@ impl completion::CompletionModel for CompletionModel {
} }
impl StreamingCompletionModel for CompletionModel { impl StreamingCompletionModel for CompletionModel {
type Response = openai::StreamingCompletionResponse; type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream( async fn stream(
&self, &self,
request: CompletionRequest, request: CompletionRequest,
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> { ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(request)?; let mut request = self.create_completion_request(request)?;
request = merge( request = merge(

View File

@ -39,7 +39,7 @@
//! let extractor = client.extractor::<serde_json::Value>("llama3.2"); //! let extractor = client.extractor::<serde_json::Value>("llama3.2");
//! ``` //! ```
use crate::json_utils::merge_inplace; use crate::json_utils::merge_inplace;
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult}; use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
use crate::{ use crate::{
agent::AgentBuilder, agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest}, completion::{self, CompletionError, CompletionRequest},
@ -405,30 +405,25 @@ impl completion::CompletionModel for CompletionModel {
} }
} }
#[derive(Clone)]
pub struct StreamingCompletionResponse { pub struct StreamingCompletionResponse {
#[serde(default)]
pub done_reason: Option<String>, pub done_reason: Option<String>,
#[serde(default)]
pub total_duration: Option<u64>, pub total_duration: Option<u64>,
#[serde(default)]
pub load_duration: Option<u64>, pub load_duration: Option<u64>,
#[serde(default)]
pub prompt_eval_count: Option<u64>, pub prompt_eval_count: Option<u64>,
#[serde(default)]
pub prompt_eval_duration: Option<u64>, pub prompt_eval_duration: Option<u64>,
#[serde(default)]
pub eval_count: Option<u64>, pub eval_count: Option<u64>,
#[serde(default)]
pub eval_duration: Option<u64>, pub eval_duration: Option<u64>,
} }
impl StreamingCompletionModel for CompletionModel { impl StreamingCompletionModel for CompletionModel {
type Response = StreamingCompletionResponse; type StreamingResponse = StreamingCompletionResponse;
async fn stream( async fn stream(
&self, &self,
request: CompletionRequest, request: CompletionRequest,
) -> Result<streaming::StreamingCompletionResponse<Self::Response>, CompletionError> { ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
{
let mut request_payload = self.create_completion_request(request)?; let mut request_payload = self.create_completion_request(request)?;
merge_inplace(&mut request_payload, json!({"stream": true})); merge_inplace(&mut request_payload, json!({"stream": true}));
@ -448,7 +443,6 @@ impl StreamingCompletionModel for CompletionModel {
return Err(CompletionError::ProviderError(err_text)); return Err(CompletionError::ProviderError(err_text));
} }
let mut
let stream = Box::pin(stream! { let stream = Box::pin(stream! {
let mut stream = response.bytes_stream(); let mut stream = response.bytes_stream();
while let Some(chunk_result) = stream.next().await { while let Some(chunk_result) = stream.next().await {
@ -495,6 +489,8 @@ impl StreamingCompletionModel for CompletionModel {
} }
} }
}); });
Ok(streaming::StreamingCompletionResponse::new(stream))
} }
} }

View File

@ -4,14 +4,13 @@ use crate::json_utils;
use crate::json_utils::merge; use crate::json_utils::merge;
use crate::providers::openai::Usage; use crate::providers::openai::Usage;
use crate::streaming; use crate::streaming;
use crate::streaming::{StreamingCompletionModel, StreamingResult}; use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
use async_stream::stream; use async_stream::stream;
use futures::StreamExt; use futures::StreamExt;
use reqwest::RequestBuilder; use reqwest::RequestBuilder;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json; use serde_json::json;
use std::collections::HashMap; use std::collections::HashMap;
use tokio::stream;
// ================================================================ // ================================================================
// OpenAI Completion Streaming API // OpenAI Completion Streaming API
@ -47,18 +46,21 @@ struct StreamingChoice {
struct StreamingCompletionChunk { struct StreamingCompletionChunk {
choices: Vec<StreamingChoice>, choices: Vec<StreamingChoice>,
usage: Option<Usage>, usage: Option<Usage>,
finish_reason: Option<String>,
} }
#[derive(Clone)]
pub struct StreamingCompletionResponse { pub struct StreamingCompletionResponse {
usage: Option<Usage>, pub usage: Option<Usage>,
} }
impl StreamingCompletionModel for CompletionModel { impl StreamingCompletionModel for CompletionModel {
type Response = StreamingCompletionResponse; type StreamingResponse = StreamingCompletionResponse;
async fn stream( async fn stream(
&self, &self,
completion_request: CompletionRequest, completion_request: CompletionRequest,
) -> Result<streaming::StreamingCompletionResponse<Self::Response>, CompletionError> { ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
{
let mut request = self.create_completion_request(completion_request)?; let mut request = self.create_completion_request(completion_request)?;
request = merge(request, json!({"stream": true})); request = merge(request, json!({"stream": true}));
@ -179,9 +181,12 @@ pub async fn send_compatible_streaming_request(
yield Ok(streaming::RawStreamingChoice::Message(content.clone())) yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
} }
if &data.usage.is_some() { if data.finish_reason.is_some() {
usage = data.usage; yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
usage: data.usage
}))
} }
} }
} }
@ -194,7 +199,5 @@ pub async fn send_compatible_streaming_request(
} }
}); });
Ok(streaming::StreamingCompletionResponse::new( Ok(streaming::StreamingCompletionResponse::new(inner))
inner,
))
} }

View File

@ -18,8 +18,9 @@ use crate::{
use crate::completion::CompletionRequest; use crate::completion::CompletionRequest;
use crate::json_utils::merge; use crate::json_utils::merge;
use crate::providers::openai;
use crate::providers::openai::send_compatible_streaming_request; use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::{StreamingCompletionModel, StreamingResult}; use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{json, Value}; use serde_json::{json, Value};
@ -345,10 +346,11 @@ impl completion::CompletionModel for CompletionModel {
} }
impl StreamingCompletionModel for CompletionModel { impl StreamingCompletionModel for CompletionModel {
type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream( async fn stream(
&self, &self,
completion_request: completion::CompletionRequest, completion_request: completion::CompletionRequest,
) -> Result<StreamingResult, CompletionError> { ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(completion_request)?; let mut request = self.create_completion_request(completion_request)?;
request = merge(request, json!({"stream": true})); request = merge(request, json!({"stream": true}));

View File

@ -1,18 +1,21 @@
use serde_json::json; use serde_json::json;
use super::completion::CompletionModel; use super::completion::CompletionModel;
use crate::providers::openai;
use crate::providers::openai::send_compatible_streaming_request; use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::StreamingCompletionResponse;
use crate::{ use crate::{
completion::{CompletionError, CompletionRequest}, completion::{CompletionError, CompletionRequest},
json_utils::merge, json_utils::merge,
streaming::{StreamingCompletionModel, StreamingResult}, streaming::StreamingCompletionModel,
}; };
impl StreamingCompletionModel for CompletionModel { impl StreamingCompletionModel for CompletionModel {
type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream( async fn stream(
&self, &self,
completion_request: CompletionRequest, completion_request: CompletionRequest,
) -> Result<StreamingResult, CompletionError> { ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(completion_request)?; let mut request = self.create_completion_request(completion_request)?;
request = merge(request, json!({"stream_tokens": true})); request = merge(request, json!({"stream_tokens": true}));

View File

@ -1,15 +1,17 @@
use crate::completion::{CompletionError, CompletionRequest}; use crate::completion::{CompletionError, CompletionRequest};
use crate::json_utils::merge; use crate::json_utils::merge;
use crate::providers::openai;
use crate::providers::openai::send_compatible_streaming_request; use crate::providers::openai::send_compatible_streaming_request;
use crate::providers::xai::completion::CompletionModel; use crate::providers::xai::completion::CompletionModel;
use crate::streaming::{StreamingCompletionModel, StreamingResult}; use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use serde_json::json; use serde_json::json;
impl StreamingCompletionModel for CompletionModel { impl StreamingCompletionModel for CompletionModel {
type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream( async fn stream(
&self, &self,
completion_request: CompletionRequest, completion_request: CompletionRequest,
) -> Result<StreamingResult, CompletionError> { ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(completion_request)?; let mut request = self.create_completion_request(completion_request)?;
request = merge(request, json!({"stream": true})); request = merge(request, json!({"stream": true}));

View File

@ -14,6 +14,7 @@ use crate::completion::{
CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, Message, CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, Message,
}; };
use crate::message::AssistantContent; use crate::message::AssistantContent;
use crate::OneOrMany;
use futures::{Stream, StreamExt}; use futures::{Stream, StreamExt};
use std::boxed::Box; use std::boxed::Box;
use std::fmt::{Display, Formatter}; use std::fmt::{Display, Formatter};
@ -23,7 +24,7 @@ use std::task::{Context, Poll};
/// Enum representing a streaming chunk from the model /// Enum representing a streaming chunk from the model
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum RawStreamingChoice<R> { pub enum RawStreamingChoice<R: Clone> {
/// A text chunk from a message response /// A text chunk from a message response
Message(String), Message(String),
@ -35,6 +36,7 @@ pub enum RawStreamingChoice<R> {
} }
/// Enum representing a streaming chunk from the model /// Enum representing a streaming chunk from the model
#[derive(Debug, Clone)]
pub enum StreamingChoice { pub enum StreamingChoice {
/// A text chunk from a message response /// A text chunk from a message response
Message(String), Message(String),
@ -61,7 +63,7 @@ pub type StreamingResult<R> =
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
pub type StreamingResult = Pin<Box<dyn Stream<Item = Result<RawStreamingChoice, CompletionError>>>>; pub type StreamingResult = Pin<Box<dyn Stream<Item = Result<RawStreamingChoice, CompletionError>>>>;
pub struct StreamingCompletionResponse<R> { pub struct StreamingCompletionResponse<R: Clone + Unpin> {
inner: StreamingResult<R>, inner: StreamingResult<R>,
text: String, text: String,
tool_calls: Vec<(String, String, serde_json::Value)>, tool_calls: Vec<(String, String, serde_json::Value)>,
@ -69,7 +71,7 @@ pub struct StreamingCompletionResponse<R> {
pub response: Option<R>, pub response: Option<R>,
} }
impl<R> StreamingCompletionResponse<R> { impl<R: Clone + Unpin> StreamingCompletionResponse<R> {
pub fn new(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> { pub fn new(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
Self { Self {
inner, inner,
@ -81,37 +83,43 @@ impl<R> StreamingCompletionResponse<R> {
} }
} }
impl<R> Stream for StreamingCompletionResponse<R> { impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
type Item = Result<StreamingChoice, CompletionError>; type Item = Result<StreamingChoice, CompletionError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.inner.poll_next(cx) { let stream = self.get_mut();
match stream.inner.as_mut().poll_next(cx) {
Poll::Pending => Poll::Pending, Poll::Pending => Poll::Pending,
Poll::Ready(None) => { Poll::Ready(None) => {
let content = vec![AssistantContent::text(self.text.clone())]; let content = vec![AssistantContent::text(stream.text.clone())];
self.tool_calls stream.tool_calls.iter().for_each(|(n, d, a)| {
.iter() AssistantContent::tool_call(n, d, a.clone());
.for_each(|(n, d, a)| AssistantContent::tool_call(n, derive!(), a)); });
self.message = Message::Assistant { stream.message = Message::Assistant {
content: content.into(), content: OneOrMany::many(content)
} .expect("There should be at least one assistant message"),
};
Poll::Ready(None)
} }
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
Poll::Ready(Some(Ok(choice))) => match choice { Poll::Ready(Some(Ok(choice))) => match choice {
RawStreamingChoice::Message(text) => { RawStreamingChoice::Message(text) => {
self.text = format!("{}{}", self.text, text); stream.text = format!("{}{}", stream.text, text.clone());
Poll::Ready(Some(Ok(choice.clone()))) Poll::Ready(Some(Ok(StreamingChoice::Message(text))))
} }
RawStreamingChoice::ToolCall(name, description, args) => { RawStreamingChoice::ToolCall(name, description, args) => {
self.tool_calls stream
.push((name, description, args)); .tool_calls
Poll::Ready(Some(Ok(choice.clone()))) .push((name.clone(), description.clone(), args.clone()));
Poll::Ready(Some(Ok(StreamingChoice::ToolCall(name, description, args))))
} }
RawStreamingChoice::FinalResponse(response) => { RawStreamingChoice::FinalResponse(response) => {
self.response = Some(response); stream.response = Some(response);
Poll::Pending Poll::Pending
} }
}, },
@ -120,7 +128,7 @@ impl<R> Stream for StreamingCompletionResponse<R> {
} }
/// Trait for high-level streaming prompt interface /// Trait for high-level streaming prompt interface
pub trait StreamingPrompt<R>: Send + Sync { pub trait StreamingPrompt<R: Clone + Unpin>: Send + Sync {
/// Stream a simple prompt to the model /// Stream a simple prompt to the model
fn stream_prompt( fn stream_prompt(
&self, &self,
@ -129,7 +137,7 @@ pub trait StreamingPrompt<R>: Send + Sync {
} }
/// Trait for high-level streaming chat interface /// Trait for high-level streaming chat interface
pub trait StreamingChat<R>: Send + Sync { pub trait StreamingChat<R: Clone + Unpin>: Send + Sync {
/// Stream a chat with history to the model /// Stream a chat with history to the model
fn stream_chat( fn stream_chat(
&self, &self,
@ -150,18 +158,20 @@ pub trait StreamingCompletion<M: StreamingCompletionModel> {
/// Trait defining a streaming completion model /// Trait defining a streaming completion model
pub trait StreamingCompletionModel: CompletionModel { pub trait StreamingCompletionModel: CompletionModel {
type Response; type StreamingResponse: Clone + Unpin;
/// Stream a completion response for the given request /// Stream a completion response for the given request
fn stream( fn stream(
&self, &self,
request: CompletionRequest, request: CompletionRequest,
) -> impl Future<Output = Result<StreamingCompletionResponse<Self::Response>, CompletionError>>; ) -> impl Future<
Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
>;
} }
/// helper function to stream a completion request to stdout /// helper function to stream a completion request to stdout
pub async fn stream_to_stdout<M: StreamingCompletionModel, R>( pub async fn stream_to_stdout<M: StreamingCompletionModel>(
agent: Agent<M>, agent: Agent<M>,
stream: &mut StreamingCompletionResponse<R>, stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
) -> Result<(), std::io::Error> { ) -> Result<(), std::io::Error> {
print!("Response: "); print!("Response: ");
while let Some(chunk) = stream.next().await { while let Some(chunk) = stream.next().await {