fix: compiles + formatted

yavens 2025-04-09 12:54:55 -04:00
parent fcbe648f77
commit 86b84c82fb
20 changed files with 175 additions and 103 deletions

View File

@ -16,6 +16,10 @@ async fn main() -> Result<(), anyhow::Error> {
.await?;
stream_to_stdout(agent, &mut stream).await?;
if let Some(response) = stream.response {
println!("Usage: {:?}", response.usage)
};
Ok(())
}
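
Note on the example above: stream.response stays None until the stream has been fully drained, because the usage arrives as a final, internal chunk. A minimal std-only sketch of that calling pattern (illustrative types, not rig's actual items):

#[derive(Debug, Clone)]
pub struct PartialUsage {
    pub input_tokens: u64,
    pub output_tokens: u64,
}

pub struct StreamedReply {
    pub text: String,
    // Populated only once the provider's final chunk has been seen.
    pub response: Option<PartialUsage>,
}

fn main() {
    let mut reply = StreamedReply { text: String::new(), response: None };

    // Stand-in for draining the stream chunk by chunk.
    for chunk in ["Hello", ", ", "world"] {
        reply.text.push_str(chunk);
    }
    // Stand-in for the closing chunk that carries token counts.
    reply.response = Some(PartialUsage { input_tokens: 12, output_tokens: 3 });

    // Mirrors the `if let Some(response) = stream.response` above.
    if let Some(usage) = &reply.response {
        println!("Usage: {usage:?}");
    }
}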

View File

@ -110,23 +110,20 @@ use std::collections::HashMap;
use futures::{stream, StreamExt, TryStreamExt};
use crate::streaming::StreamingCompletionResponse;
#[cfg(feature = "mcp")]
use crate::tool::McpTool;
use crate::{
completion::{
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
Message, Prompt, PromptError,
},
message::AssistantContent,
streaming::{
StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
StreamingResult,
},
streaming::{StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt},
tool::{Tool, ToolSet},
vector_store::{VectorStoreError, VectorStoreIndexDyn},
};
#[cfg(feature = "mcp")]
use crate::tool::McpTool;
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
/// All context documents and tools are always provided to the agent when prompted.
@ -500,18 +497,21 @@ impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
}
}
impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
impl<M: StreamingCompletionModel> StreamingPrompt<M::StreamingResponse> for Agent<M> {
async fn stream_prompt(
&self,
prompt: &str,
) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
self.stream_chat(prompt, vec![]).await
}
}
impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
impl<M: StreamingCompletionModel> StreamingChat<M::StreamingResponse> for Agent<M> {
async fn stream_chat(
&self,
prompt: &str,
chat_history: Vec<Message>,
) -> Result<StreamingResult, CompletionError> {
) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
self.stream_completion(prompt, chat_history)
.await?
.stream()
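
The hunk above threads the provider's response type through the high-level traits: StreamingPrompt and StreamingChat become generic over M::StreamingResponse, so Agent forwards whatever concrete type the model defines instead of a one-size-fits-all StreamingResult. A compressed, synchronous sketch of that type threading (illustrative names, not rig's actual API):

pub trait StreamingModel {
    // Each provider picks its own final-response type.
    type StreamingResponse: Clone;
    fn stream(&self, prompt: &str) -> Self::StreamingResponse;
}

pub trait StreamingPrompt<R> {
    fn stream_prompt(&self, prompt: &str) -> R;
}

pub struct Agent<M>(pub M);

// As in the diff, the impl is generic over the model's associated type:
// Agent adds no response type of its own, it just forwards.
impl<M: StreamingModel> StreamingPrompt<M::StreamingResponse> for Agent<M> {
    fn stream_prompt(&self, prompt: &str) -> M::StreamingResponse {
        self.0.stream(prompt)
    }
}

struct Echo;
impl StreamingModel for Echo {
    type StreamingResponse = String;
    fn stream(&self, prompt: &str) -> String {
        prompt.to_string()
    }
}

fn main() {
    assert_eq!(Agent(Echo).stream_prompt("hi"), "hi");
}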

View File

@ -67,7 +67,7 @@ use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use crate::OneOrMany;
use crate::{
json_utils,
@ -467,7 +467,9 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
impl<M: StreamingCompletionModel> CompletionRequestBuilder<M> {
/// Stream the completion request
pub async fn stream(self) -> Result<StreamingResult, CompletionError> {
pub async fn stream(
self,
) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
let model = self.model.clone();
model.stream(self.build()).await
}

View File

@ -8,6 +8,7 @@ use super::decoders::sse::from_response as sse_from_response;
use crate::completion::{CompletionError, CompletionRequest};
use crate::json_utils::merge_inplace;
use crate::message::MessageError;
use crate::streaming;
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult};
#[derive(Debug, Deserialize)]
@ -61,7 +62,7 @@ pub struct MessageDelta {
pub stop_sequence: Option<String>,
}
#[derive(Debug, Deserialize)]
#[derive(Debug, Deserialize, Clone)]
pub struct PartialUsage {
pub output_tokens: usize,
#[serde(default)]
@ -75,11 +76,18 @@ struct ToolCallState {
input_json: String,
}
#[derive(Clone)]
pub struct StreamingCompletionResponse {
pub usage: PartialUsage,
}
impl StreamingCompletionModel for CompletionModel {
type StreamingResponse = StreamingCompletionResponse;
async fn stream(
&self,
completion_request: CompletionRequest,
) -> Result<StreamingResult, CompletionError> {
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
{
let max_tokens = if let Some(tokens) = completion_request.max_tokens {
tokens
} else if let Some(tokens) = self.default_max_tokens {
@ -155,9 +163,10 @@ impl StreamingCompletionModel for CompletionModel {
// Use our SSE decoder to directly handle Server-Sent Events format
let sse_stream = sse_from_response(response);
Ok(Box::pin(stream! {
let stream: StreamingResult<Self::StreamingResponse> = Box::pin(stream! {
let mut current_tool_call: Option<ToolCallState> = None;
let mut sse_stream = Box::pin(sse_stream);
let mut input_tokens = 0;
while let Some(sse_result) = sse_stream.next().await {
match sse_result {
@ -165,6 +174,24 @@ impl StreamingCompletionModel for CompletionModel {
// Parse the SSE data as a StreamingEvent
match serde_json::from_str::<StreamingEvent>(&sse.data) {
Ok(event) => {
match &event {
StreamingEvent::MessageStart { message } => {
input_tokens = message.usage.input_tokens;
},
StreamingEvent::MessageDelta { delta, usage } => {
if delta.stop_reason.is_some() {
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
usage: PartialUsage {
output_tokens: usage.output_tokens,
input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
}
}))
}
}
_ => {}
}
if let Some(result) = handle_event(&event, &mut current_tool_call) {
yield result;
}
@ -184,14 +211,16 @@ impl StreamingCompletionModel for CompletionModel {
}
}
}
}))
});
Ok(streaming::StreamingCompletionResponse::new(stream))
}
}
fn handle_event(
event: &StreamingEvent,
current_tool_call: &mut Option<ToolCallState>,
) -> Option<Result<RawStreamingChoice, CompletionError>> {
) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
match event {
StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
ContentDelta::TextDelta { text } => {
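
Worth noting in the hunk above: Anthropic splits usage across events, so input_tokens must be captured from MessageStart and held until the MessageDelta that carries stop_reason and the final output_tokens. A std-only, synchronous model of that accumulation (the real code does this inside the stream! block; names are illustrative):

#[derive(Debug, Clone)]
enum Event {
    MessageStart { input_tokens: u64 },
    TextDelta(String),
    MessageDelta { stop_reason: Option<String>, output_tokens: u64 },
}

#[derive(Debug, Clone)]
struct Usage {
    input_tokens: u64,
    output_tokens: u64,
}

fn final_usage(events: &[Event]) -> Option<Usage> {
    let mut input_tokens = 0;
    for event in events {
        match event {
            // Arrives first; remember it for later.
            Event::MessageStart { input_tokens: n } => input_tokens = *n,
            // Arrives last; only now can the full usage be assembled.
            Event::MessageDelta { stop_reason: Some(_), output_tokens } => {
                return Some(Usage { input_tokens, output_tokens: *output_tokens });
            }
            _ => {}
        }
    }
    None
}

fn main() {
    let events = [
        Event::MessageStart { input_tokens: 9 },
        Event::TextDelta("hi".into()),
        Event::MessageDelta { stop_reason: Some("end_turn".into()), output_tokens: 2 },
    ];
    // Some(Usage { input_tokens: 9, output_tokens: 2 })
    println!("{:?}", final_usage(&events));
}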

View File

@ -12,7 +12,7 @@
use super::openai::{send_compatible_streaming_request, TranscriptionResponse};
use crate::json_utils::merge;
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use crate::{
agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest},
@ -570,11 +570,11 @@ impl completion::CompletionModel for CompletionModel {
// Azure OpenAI Streaming API
// -----------------------------------------------------
impl StreamingCompletionModel for CompletionModel {
type Response = openai::StreamingCompletionResponse;
type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream(
&self,
request: CompletionRequest,
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(request)?;
request = merge(

View File

@ -12,7 +12,7 @@
use crate::json_utils::merge;
use crate::providers::openai;
use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use crate::{
completion::{self, CompletionError, CompletionModel, CompletionRequest},
extractor::ExtractorBuilder,
@ -464,11 +464,11 @@ impl CompletionModel for DeepSeekCompletionModel {
}
impl StreamingCompletionModel for DeepSeekCompletionModel {
type Response = openai::StreamingCompletionResponse;
type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream(
&self,
completion_request: CompletionRequest,
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(completion_request)?;
request = merge(

View File

@ -13,7 +13,7 @@
use super::openai;
use crate::json_utils::merge;
use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use crate::{
agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest},
@ -495,12 +495,12 @@ impl completion::CompletionModel for CompletionModel {
}
impl StreamingCompletionModel for CompletionModel {
type Response = openai::StreamingCompletionResponse;
type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream(
&self,
request: CompletionRequest,
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(request)?;
request = merge(

View File

@ -609,7 +609,7 @@ pub mod gemini_api_types {
HarmCategoryCivicIntegrity,
}
#[derive(Debug, Deserialize)]
#[derive(Debug, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
pub prompt_token_count: i32,

View File

@ -2,26 +2,34 @@ use async_stream::stream;
use futures::StreamExt;
use serde::Deserialize;
use super::completion::{create_request_body, gemini_api_types::ContentCandidate, CompletionModel};
use crate::providers::gemini::completion::gemini_api_types::UsageMetadata;
use crate::{
completion::{CompletionError, CompletionRequest},
streaming::{self, StreamingCompletionModel, StreamingResult},
streaming::{self, StreamingCompletionModel},
};
use super::completion::{create_request_body, gemini_api_types::ContentCandidate, CompletionModel};
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StreamGenerateContentResponse {
/// Candidate responses from the model.
pub candidates: Vec<ContentCandidate>,
pub model_version: Option<String>,
pub usage_metadata: UsageMetadata,
}
#[derive(Clone)]
pub struct StreamingCompletionResponse {
pub usage_metadata: UsageMetadata,
}
impl StreamingCompletionModel for CompletionModel {
type StreamingResponse = StreamingCompletionResponse;
async fn stream(
&self,
completion_request: CompletionRequest,
) -> Result<StreamingResult, CompletionError> {
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
{
let request = create_request_body(completion_request)?;
let response = self
@ -42,7 +50,7 @@ impl StreamingCompletionModel for CompletionModel {
)));
}
Ok(Box::pin(stream! {
let stream = Box::pin(stream! {
let mut stream = response.bytes_stream();
while let Some(chunk_result) = stream.next().await {
@ -79,8 +87,16 @@ impl StreamingCompletionModel for CompletionModel {
=> yield Ok(streaming::RawStreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
_ => panic!("Unsupported response type with streaming.")
};
if choice.finish_reason.is_some() {
yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
usage_metadata: data.usage_metadata,
}))
}
}
}
}))
});
Ok(streaming::StreamingCompletionResponse::new(stream))
}
}
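
Unlike Anthropic, Gemini repeats usage_metadata on every streamed chunk, so the code above simply forwards the metadata from the chunk whose candidate has finish_reason set. A deserialization sketch of that chunk shape (camelCase per the rename_all attribute earlier in this commit; the candidates_token_count field and all values are assumed here for illustration):

use serde::Deserialize;

#[derive(Debug, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
struct UsageMetadata {
    prompt_token_count: i32,
    #[serde(default)]
    candidates_token_count: Option<i32>,
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct Chunk {
    usage_metadata: UsageMetadata,
}

fn main() {
    let raw = r#"{"usageMetadata":{"promptTokenCount":7,"candidatesTokenCount":3}}"#;
    let chunk: Chunk = serde_json::from_str(raw).unwrap();
    // Every chunk carries totals-so-far; the one whose candidate has
    // finish_reason set is the one surfaced as the FinalResponse.
    println!("{:?}", chunk.usage_metadata);
}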

View File

@ -11,7 +11,7 @@
use super::openai::{send_compatible_streaming_request, CompletionResponse, TranscriptionResponse};
use crate::json_utils::merge;
use crate::providers::openai;
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use crate::{
agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest},
@ -364,11 +364,11 @@ impl completion::CompletionModel for CompletionModel {
}
impl StreamingCompletionModel for CompletionModel {
type Response = openai::StreamingCompletionResponse;
type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream(
&self,
request: CompletionRequest,
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(request)?;
request = merge(

View File

@ -1,9 +1,9 @@
use super::completion::CompletionModel;
use crate::completion::{CompletionError, CompletionRequest};
use crate::json_utils;
use crate::json_utils::merge_inplace;
use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::providers::openai::{send_compatible_streaming_request, StreamingCompletionResponse};
use crate::streaming::StreamingCompletionModel;
use crate::{json_utils, streaming};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::convert::Infallible;
@ -55,14 +55,19 @@ struct CompletionChunk {
}
impl StreamingCompletionModel for CompletionModel {
type StreamingResponse = StreamingCompletionResponse;
async fn stream(
&self,
completion_request: CompletionRequest,
) -> Result<StreamingResult, CompletionError> {
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
{
let mut request = self.create_request_body(&completion_request)?;
// Enable streaming
merge_inplace(&mut request, json!({"stream": true}));
merge_inplace(
&mut request,
json!({"stream": true, "stream_options": {"include_usage": true}}),
);
if let Some(ref params) = completion_request.additional_params {
merge_inplace(&mut request, params.clone());
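
The include_usage flag added above is what makes a usage chunk exist at all on OpenAI-compatible backends: without it the stream simply ends with no token counts. A sketch of what the merged request body looks like (rig's merge_inplace is modeled inline as a shallow key-wise merge, an assumption for illustration):

use serde_json::{json, Value};

// Stand-in for rig's merge_inplace helper.
fn merge_inplace(base: &mut Value, patch: Value) {
    if let (Value::Object(base), Value::Object(patch)) = (base, patch) {
        base.extend(patch);
    }
}

fn main() {
    let mut request = json!({"model": "some-model", "messages": []});
    merge_inplace(
        &mut request,
        json!({"stream": true, "stream_options": {"include_usage": true}}),
    );
    // With include_usage set, the backend sends one extra chunk carrying
    // usage totals (and empty choices) before the stream terminates.
    println!("{request}");
}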

View File

@ -12,7 +12,7 @@
use super::openai::{send_compatible_streaming_request, AssistantContent};
use crate::json_utils::merge_inplace;
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use crate::{
agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest},
@ -390,11 +390,11 @@ impl completion::CompletionModel for CompletionModel {
}
impl StreamingCompletionModel for CompletionModel {
type Response = openai::StreamingCompletionResponse;
type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream(
&self,
completion_request: CompletionRequest,
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(completion_request)?;
merge_inplace(

View File

@ -8,8 +8,9 @@
//!
//! ```
use crate::json_utils::merge;
use crate::providers::openai;
use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use crate::{
agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest},
@ -24,7 +25,6 @@ use serde_json::{json, Value};
use std::string::FromUtf8Error;
use thiserror::Error;
use tracing;
use crate::providers::openai;
#[derive(Debug, Error)]
pub enum MiraError {
@ -348,11 +348,11 @@ impl completion::CompletionModel for CompletionModel {
}
impl StreamingCompletionModel for CompletionModel {
type Response = openai::StreamingCompletionResponse;
type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream(
&self,
completion_request: CompletionRequest,
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(completion_request)?;
request = merge(request, json!({"stream": true}));

View File

@ -11,7 +11,7 @@
use crate::json_utils::merge;
use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use crate::{
agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest},
@ -228,12 +228,12 @@ impl completion::CompletionModel for CompletionModel {
}
impl StreamingCompletionModel for CompletionModel {
type Response = openai::StreamingCompletionResponse;
type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream(
&self,
request: CompletionRequest,
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(request)?;
request = merge(

View File

@ -39,7 +39,7 @@
//! let extractor = client.extractor::<serde_json::Value>("llama3.2");
//! ```
use crate::json_utils::merge_inplace;
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult};
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
use crate::{
agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest},
@ -405,30 +405,25 @@ impl completion::CompletionModel for CompletionModel {
}
}
#[derive(Clone)]
pub struct StreamingCompletionResponse {
#[serde(default)]
pub done_reason: Option<String>,
#[serde(default)]
pub total_duration: Option<u64>,
#[serde(default)]
pub load_duration: Option<u64>,
#[serde(default)]
pub prompt_eval_count: Option<u64>,
#[serde(default)]
pub prompt_eval_duration: Option<u64>,
#[serde(default)]
pub eval_count: Option<u64>,
#[serde(default)]
pub eval_duration: Option<u64>,
}
impl StreamingCompletionModel for CompletionModel {
type Response = StreamingCompletionResponse;
type StreamingResponse = StreamingCompletionResponse;
async fn stream(
&self,
request: CompletionRequest,
) -> Result<streaming::StreamingCompletionResponse<Self::Response>, CompletionError> {
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
{
let mut request_payload = self.create_completion_request(request)?;
merge_inplace(&mut request_payload, json!({"stream": true}));
@ -448,7 +443,6 @@ impl StreamingCompletionModel for CompletionModel {
return Err(CompletionError::ProviderError(err_text));
}
let mut
let stream = Box::pin(stream! {
let mut stream = response.bytes_stream();
while let Some(chunk_result) = stream.next().await {
@ -495,6 +489,8 @@ impl StreamingCompletionModel for CompletionModel {
}
}
});
Ok(streaming::StreamingCompletionResponse::new(stream))
}
}

View File

@ -4,14 +4,13 @@ use crate::json_utils;
use crate::json_utils::merge;
use crate::providers::openai::Usage;
use crate::streaming;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
use async_stream::stream;
use futures::StreamExt;
use reqwest::RequestBuilder;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
use tokio::stream;
// ================================================================
// OpenAI Completion Streaming API
@ -47,18 +46,21 @@ struct StreamingChoice {
struct StreamingCompletionChunk {
choices: Vec<StreamingChoice>,
usage: Option<Usage>,
finish_reason: Option<String>,
}
#[derive(Clone)]
pub struct StreamingCompletionResponse {
usage: Option<Usage>,
pub usage: Option<Usage>,
}
impl StreamingCompletionModel for CompletionModel {
type Response = StreamingCompletionResponse;
type StreamingResponse = StreamingCompletionResponse;
async fn stream(
&self,
completion_request: CompletionRequest,
) -> Result<streaming::StreamingCompletionResponse<Self::Response>, CompletionError> {
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
{
let mut request = self.create_completion_request(completion_request)?;
request = merge(request, json!({"stream": true}));
@ -179,9 +181,12 @@ pub async fn send_compatible_streaming_request(
yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
}
if &data.usage.is_some() {
usage = data.usage;
if data.finish_reason.is_some() {
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
usage: data.usage
}))
}
}
}
@ -194,7 +199,5 @@ pub async fn send_compatible_streaming_request(
}
});
Ok(streaming::StreamingCompletionResponse::new(
inner,
))
Ok(streaming::StreamingCompletionResponse::new(inner))
}
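
The finish_reason/usage handling above leans on the fact that OpenAI-style streams only carry usage on the closing chunk. A deserialization sketch of that shape (field placement follows this commit's StreamingCompletionChunk; the JSON values are invented):

use serde::Deserialize;

#[derive(Debug, Deserialize, Clone)]
struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
}

#[derive(Debug, Deserialize)]
struct Chunk {
    #[serde(default)]
    usage: Option<Usage>,
    #[serde(default)]
    finish_reason: Option<String>,
}

fn main() {
    // Ordinary chunks: no usage, no finish_reason.
    let mid: Chunk = serde_json::from_str("{}").unwrap();
    assert!(mid.usage.is_none() && mid.finish_reason.is_none());

    // Closing chunk: both present, so the FinalResponse can be emitted.
    let last: Chunk = serde_json::from_str(
        r#"{"finish_reason":"stop","usage":{"prompt_tokens":5,"completion_tokens":2}}"#,
    )
    .unwrap();
    if last.finish_reason.is_some() {
        println!("{:?}", last.usage);
    }
}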

View File

@ -18,8 +18,9 @@ use crate::{
use crate::completion::CompletionRequest;
use crate::json_utils::merge;
use crate::providers::openai;
use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
@ -345,10 +346,11 @@ impl completion::CompletionModel for CompletionModel {
}
impl StreamingCompletionModel for CompletionModel {
type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream(
&self,
completion_request: completion::CompletionRequest,
) -> Result<StreamingResult, CompletionError> {
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(completion_request)?;
request = merge(request, json!({"stream": true}));

View File

@ -1,18 +1,21 @@
use serde_json::json;
use super::completion::CompletionModel;
use crate::providers::openai;
use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::StreamingCompletionResponse;
use crate::{
completion::{CompletionError, CompletionRequest},
json_utils::merge,
streaming::{StreamingCompletionModel, StreamingResult},
streaming::StreamingCompletionModel,
};
impl StreamingCompletionModel for CompletionModel {
type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream(
&self,
completion_request: CompletionRequest,
) -> Result<StreamingResult, CompletionError> {
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(completion_request)?;
request = merge(request, json!({"stream_tokens": true}));

View File

@ -1,15 +1,17 @@
use crate::completion::{CompletionError, CompletionRequest};
use crate::json_utils::merge;
use crate::providers::openai;
use crate::providers::openai::send_compatible_streaming_request;
use crate::providers::xai::completion::CompletionModel;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use serde_json::json;
impl StreamingCompletionModel for CompletionModel {
type StreamingResponse = openai::StreamingCompletionResponse;
async fn stream(
&self,
completion_request: CompletionRequest,
) -> Result<StreamingResult, CompletionError> {
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let mut request = self.create_completion_request(completion_request)?;
request = merge(request, json!({"stream": true}));

View File

@ -14,6 +14,7 @@ use crate::completion::{
CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, Message,
};
use crate::message::AssistantContent;
use crate::OneOrMany;
use futures::{Stream, StreamExt};
use std::boxed::Box;
use std::fmt::{Display, Formatter};
@ -23,7 +24,7 @@ use std::task::{Context, Poll};
/// Enum representing a streaming chunk from the model
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R> {
pub enum RawStreamingChoice<R: Clone> {
/// A text chunk from a message response
Message(String),
@ -35,6 +36,7 @@ pub enum RawStreamingChoice<R> {
}
/// Enum representing a streaming chunk from the model
#[derive(Debug, Clone)]
pub enum StreamingChoice {
/// A text chunk from a message response
Message(String),
@ -61,7 +63,7 @@ pub type StreamingResult<R> =
#[cfg(target_arch = "wasm32")]
pub type StreamingResult = Pin<Box<dyn Stream<Item = Result<RawStreamingChoice, CompletionError>>>>;
pub struct StreamingCompletionResponse<R> {
pub struct StreamingCompletionResponse<R: Clone + Unpin> {
inner: StreamingResult<R>,
text: String,
tool_calls: Vec<(String, String, serde_json::Value)>,
@ -69,7 +71,7 @@ pub struct StreamingCompletionResponse<R> {
pub response: Option<R>,
}
impl<R> StreamingCompletionResponse<R> {
impl<R: Clone + Unpin> StreamingCompletionResponse<R> {
pub fn new(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
Self {
inner,
@ -81,37 +83,43 @@ impl<R> StreamingCompletionResponse<R> {
}
}
impl<R> Stream for StreamingCompletionResponse<R> {
impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
type Item = Result<StreamingChoice, CompletionError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.inner.poll_next(cx) {
let stream = self.get_mut();
match stream.inner.as_mut().poll_next(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => {
let content = vec![AssistantContent::text(self.text.clone())];
let content = vec![AssistantContent::text(stream.text.clone())];
self.tool_calls
.iter()
.for_each(|(n, d, a)| AssistantContent::tool_call(n, derive!(), a));
stream.tool_calls.iter().for_each(|(n, d, a)| {
AssistantContent::tool_call(n, d, a.clone());
});
self.message = Message::Assistant {
content: content.into(),
}
stream.message = Message::Assistant {
content: OneOrMany::many(content)
.expect("There should be at least one assistant message"),
};
Poll::Ready(None)
}
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
Poll::Ready(Some(Ok(choice))) => match choice {
RawStreamingChoice::Message(text) => {
self.text = format!("{}{}", self.text, text);
Poll::Ready(Some(Ok(choice.clone())))
stream.text = format!("{}{}", stream.text, text.clone());
Poll::Ready(Some(Ok(StreamingChoice::Message(text))))
}
RawStreamingChoice::ToolCall(name, description, args) => {
self.tool_calls
.push((name, description, args));
Poll::Ready(Some(Ok(choice.clone())))
stream
.tool_calls
.push((name.clone(), description.clone(), args.clone()));
Poll::Ready(Some(Ok(StreamingChoice::ToolCall(name, description, args))))
}
RawStreamingChoice::FinalResponse(response) => {
self.response = Some(response);
stream.response = Some(response);
Poll::Pending
}
},
@ -120,7 +128,7 @@ impl<R> Stream for StreamingCompletionResponse<R> {
}
/// Trait for high-level streaming prompt interface
pub trait StreamingPrompt<R>: Send + Sync {
pub trait StreamingPrompt<R: Clone + Unpin>: Send + Sync {
/// Stream a simple prompt to the model
fn stream_prompt(
&self,
@ -129,7 +137,7 @@ pub trait StreamingPrompt<R>: Send + Sync {
}
/// Trait for high-level streaming chat interface
pub trait StreamingChat<R>: Send + Sync {
pub trait StreamingChat<R: Clone + Unpin>: Send + Sync {
/// Stream a chat with history to the model
fn stream_chat(
&self,
@ -150,18 +158,20 @@ pub trait StreamingCompletion<M: StreamingCompletionModel> {
/// Trait defining a streaming completion model
pub trait StreamingCompletionModel: CompletionModel {
type Response;
type StreamingResponse: Clone + Unpin;
/// Stream a completion response for the given request
fn stream(
&self,
request: CompletionRequest,
) -> impl Future<Output = Result<StreamingCompletionResponse<Self::Response>, CompletionError>>;
) -> impl Future<
Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
>;
}
/// helper function to stream a completion request to stdout
pub async fn stream_to_stdout<M: StreamingCompletionModel, R>(
pub async fn stream_to_stdout<M: StreamingCompletionModel>(
agent: Agent<M>,
stream: &mut StreamingCompletionResponse<R>,
stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
) -> Result<(), std::io::Error> {
print!("Response: ");
while let Some(chunk) = stream.next().await {
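
The poll_next rewrite above is the heart of the change: the wrapper aggregates message text and tool calls as chunks pass through, and a FinalResponse chunk is stashed into self.response rather than forwarded to the consumer. A self-contained model of that wrapper, using only the futures crate and illustrative names (one deliberate difference: where the diff returns Poll::Pending after stashing the final response, this sketch loops and polls again so no extra wake-up is needed):

use std::pin::Pin;
use std::task::{Context, Poll};

use futures::{executor, Stream, StreamExt};

#[derive(Debug, Clone)]
enum Raw<R> {
    Message(String),
    FinalResponse(R),
}

struct Wrapper<R> {
    inner: Pin<Box<dyn Stream<Item = Raw<R>> + Send>>,
    text: String,
    response: Option<R>,
}

impl<R: Clone + Unpin> Stream for Wrapper<R> {
    type Item = String;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // get_mut is fine because every field is Unpin, which is why the
        // diff adds the Clone + Unpin bounds on R.
        let this = self.get_mut();
        loop {
            match this.inner.as_mut().poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Ready(Some(Raw::Message(text))) => {
                    this.text.push_str(&text); // aggregate for the final message
                    return Poll::Ready(Some(text));
                }
                Poll::Ready(Some(Raw::FinalResponse(r))) => {
                    // Stash it and keep polling; the consumer never sees
                    // this internal item.
                    this.response = Some(r);
                }
            }
        }
    }
}

fn main() {
    executor::block_on(async {
        let inner = futures::stream::iter(vec![
            Raw::Message("Hello".to_string()),
            Raw::FinalResponse(42u32),
            Raw::Message("!".to_string()),
        ]);
        let mut s = Wrapper { inner: Box::pin(inner), text: String::new(), response: None };
        while let Some(chunk) = s.next().await {
            print!("{chunk}");
        }
        println!();
        println!("response: {:?}", s.response); // Some(42)
    });
}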