mirror of https://github.com/0xplaygrounds/rig
feat: start refactoring streaming api
This commit is contained in:
parent
92c91d23c3
commit
fcbe648f77
|
@ -8,7 +8,7 @@ use super::decoders::sse::from_response as sse_from_response;
|
|||
use crate::completion::{CompletionError, CompletionRequest};
|
||||
use crate::json_utils::merge_inplace;
|
||||
use crate::message::MessageError;
|
||||
use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
|
||||
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
|
@ -191,12 +191,12 @@ impl StreamingCompletionModel for CompletionModel {
|
|||
fn handle_event(
|
||||
event: &StreamingEvent,
|
||||
current_tool_call: &mut Option<ToolCallState>,
|
||||
) -> Option<Result<StreamingChoice, CompletionError>> {
|
||||
) -> Option<Result<RawStreamingChoice, CompletionError>> {
|
||||
match event {
|
||||
StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
|
||||
ContentDelta::TextDelta { text } => {
|
||||
if current_tool_call.is_none() {
|
||||
return Some(Ok(StreamingChoice::Message(text.clone())));
|
||||
return Some(Ok(RawStreamingChoice::Message(text.clone())));
|
||||
}
|
||||
None
|
||||
}
|
||||
|
@ -227,7 +227,7 @@ fn handle_event(
|
|||
&tool_call.input_json
|
||||
};
|
||||
match serde_json::from_str(json_str) {
|
||||
Ok(json_value) => Some(Ok(StreamingChoice::ToolCall(
|
||||
Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall(
|
||||
tool_call.name,
|
||||
tool_call.id,
|
||||
json_value,
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
use super::openai::{send_compatible_streaming_request, TranscriptionResponse};
|
||||
|
||||
use crate::json_utils::merge;
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
|
||||
use crate::{
|
||||
agent::AgentBuilder,
|
||||
completion::{self, CompletionError, CompletionRequest},
|
||||
|
@ -570,10 +570,17 @@ impl completion::CompletionModel for CompletionModel {
|
|||
// Azure OpenAI Streaming API
|
||||
// -----------------------------------------------------
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
|
||||
type Response = openai::StreamingCompletionResponse;
|
||||
async fn stream(
|
||||
&self,
|
||||
request: CompletionRequest,
|
||||
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
|
||||
let mut request = self.create_completion_request(request)?;
|
||||
|
||||
request = merge(request, json!({"stream": true}));
|
||||
request = merge(
|
||||
request,
|
||||
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||
);
|
||||
|
||||
let builder = self
|
||||
.client
|
||||
|
|
|
@ -10,8 +10,9 @@
|
|||
//! ```
|
||||
|
||||
use crate::json_utils::merge;
|
||||
use crate::providers::openai;
|
||||
use crate::providers::openai::send_compatible_streaming_request;
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
|
||||
use crate::{
|
||||
completion::{self, CompletionError, CompletionModel, CompletionRequest},
|
||||
extractor::ExtractorBuilder,
|
||||
|
@ -463,13 +464,17 @@ impl CompletionModel for DeepSeekCompletionModel {
|
|||
}
|
||||
|
||||
impl StreamingCompletionModel for DeepSeekCompletionModel {
|
||||
type Response = openai::StreamingCompletionResponse;
|
||||
async fn stream(
|
||||
&self,
|
||||
completion_request: CompletionRequest,
|
||||
) -> Result<StreamingResult, CompletionError> {
|
||||
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
|
||||
let mut request = self.create_completion_request(completion_request)?;
|
||||
|
||||
request = merge(request, json!({"stream": true}));
|
||||
request = merge(
|
||||
request,
|
||||
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||
);
|
||||
|
||||
let builder = self.client.post("/v1/chat/completions").json(&request);
|
||||
send_compatible_streaming_request(builder).await
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
use super::openai;
|
||||
use crate::json_utils::merge;
|
||||
use crate::providers::openai::send_compatible_streaming_request;
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
|
||||
use crate::{
|
||||
agent::AgentBuilder,
|
||||
completion::{self, CompletionError, CompletionRequest},
|
||||
|
@ -495,10 +495,18 @@ impl completion::CompletionModel for CompletionModel {
|
|||
}
|
||||
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
|
||||
type Response = openai::StreamingCompletionResponse;
|
||||
|
||||
async fn stream(
|
||||
&self,
|
||||
request: CompletionRequest,
|
||||
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
|
||||
let mut request = self.create_completion_request(request)?;
|
||||
|
||||
request = merge(request, json!({"stream": true}));
|
||||
request = merge(
|
||||
request,
|
||||
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||
);
|
||||
|
||||
let builder = self.client.post("/chat/completions").json(&request);
|
||||
|
||||
|
|
|
@ -74,9 +74,9 @@ impl StreamingCompletionModel for CompletionModel {
|
|||
|
||||
match choice.content.parts.first() {
|
||||
super::completion::gemini_api_types::Part::Text(text)
|
||||
=> yield Ok(streaming::StreamingChoice::Message(text)),
|
||||
=> yield Ok(streaming::RawStreamingChoice::Message(text)),
|
||||
super::completion::gemini_api_types::Part::FunctionCall(function_call)
|
||||
=> yield Ok(streaming::StreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
|
||||
=> yield Ok(streaming::RawStreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
|
||||
_ => panic!("Unsupported response type with streaming.")
|
||||
};
|
||||
}
|
||||
|
|
|
@ -10,7 +10,8 @@
|
|||
//! ```
|
||||
use super::openai::{send_compatible_streaming_request, CompletionResponse, TranscriptionResponse};
|
||||
use crate::json_utils::merge;
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
||||
use crate::providers::openai;
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
|
||||
use crate::{
|
||||
agent::AgentBuilder,
|
||||
completion::{self, CompletionError, CompletionRequest},
|
||||
|
@ -363,10 +364,17 @@ impl completion::CompletionModel for CompletionModel {
|
|||
}
|
||||
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
|
||||
type Response = openai::StreamingCompletionResponse;
|
||||
async fn stream(
|
||||
&self,
|
||||
request: CompletionRequest,
|
||||
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
|
||||
let mut request = self.create_completion_request(request)?;
|
||||
|
||||
request = merge(request, json!({"stream": true}));
|
||||
request = merge(
|
||||
request,
|
||||
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||
);
|
||||
|
||||
let builder = self.client.post("/chat/completions").json(&request);
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
use super::openai::{send_compatible_streaming_request, AssistantContent};
|
||||
|
||||
use crate::json_utils::merge_inplace;
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
|
||||
use crate::{
|
||||
agent::AgentBuilder,
|
||||
completion::{self, CompletionError, CompletionRequest},
|
||||
|
@ -390,13 +390,17 @@ impl completion::CompletionModel for CompletionModel {
|
|||
}
|
||||
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
type Response = openai::StreamingCompletionResponse;
|
||||
async fn stream(
|
||||
&self,
|
||||
completion_request: CompletionRequest,
|
||||
) -> Result<StreamingResult, CompletionError> {
|
||||
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
|
||||
let mut request = self.create_completion_request(completion_request)?;
|
||||
|
||||
merge_inplace(&mut request, json!({"stream": true}));
|
||||
merge_inplace(
|
||||
&mut request,
|
||||
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||
);
|
||||
|
||||
let builder = self.client.post("/chat/completions").json(&request);
|
||||
|
||||
|
@ -526,8 +530,10 @@ mod image_generation {
|
|||
// ======================================
|
||||
// Hyperbolic Audio Generation API
|
||||
// ======================================
|
||||
use crate::providers::openai;
|
||||
#[cfg(feature = "audio")]
|
||||
pub use audio_generation::*;
|
||||
|
||||
#[cfg(feature = "audio")]
|
||||
mod audio_generation {
|
||||
use super::{ApiResponse, Client};
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
//! ```
|
||||
use crate::json_utils::merge;
|
||||
use crate::providers::openai::send_compatible_streaming_request;
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
|
||||
use crate::{
|
||||
agent::AgentBuilder,
|
||||
completion::{self, CompletionError, CompletionRequest},
|
||||
|
@ -24,6 +24,7 @@ use serde_json::{json, Value};
|
|||
use std::string::FromUtf8Error;
|
||||
use thiserror::Error;
|
||||
use tracing;
|
||||
use crate::providers::openai;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum MiraError {
|
||||
|
@ -347,10 +348,11 @@ impl completion::CompletionModel for CompletionModel {
|
|||
}
|
||||
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
type Response = openai::StreamingCompletionResponse;
|
||||
async fn stream(
|
||||
&self,
|
||||
completion_request: CompletionRequest,
|
||||
) -> Result<StreamingResult, CompletionError> {
|
||||
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
|
||||
let mut request = self.create_completion_request(completion_request)?;
|
||||
|
||||
request = merge(request, json!({"stream": true}));
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
|
||||
use crate::json_utils::merge;
|
||||
use crate::providers::openai::send_compatible_streaming_request;
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
|
||||
use crate::{
|
||||
agent::AgentBuilder,
|
||||
completion::{self, CompletionError, CompletionRequest},
|
||||
|
@ -228,10 +228,18 @@ impl completion::CompletionModel for CompletionModel {
|
|||
}
|
||||
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
|
||||
type Response = openai::StreamingCompletionResponse;
|
||||
|
||||
async fn stream(
|
||||
&self,
|
||||
request: CompletionRequest,
|
||||
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
|
||||
let mut request = self.create_completion_request(request)?;
|
||||
|
||||
request = merge(request, json!({"stream": true}));
|
||||
request = merge(
|
||||
request,
|
||||
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||
);
|
||||
|
||||
let builder = self.client.post("/chat/completions").json(&request);
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@
|
|||
//! let extractor = client.extractor::<serde_json::Value>("llama3.2");
|
||||
//! ```
|
||||
use crate::json_utils::merge_inplace;
|
||||
use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
|
||||
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult};
|
||||
use crate::{
|
||||
agent::AgentBuilder,
|
||||
completion::{self, CompletionError, CompletionRequest},
|
||||
|
@ -47,7 +47,7 @@ use crate::{
|
|||
extractor::ExtractorBuilder,
|
||||
json_utils, message,
|
||||
message::{ImageDetail, Text},
|
||||
Embed, OneOrMany,
|
||||
streaming, Embed, OneOrMany,
|
||||
};
|
||||
use async_stream::stream;
|
||||
use futures::StreamExt;
|
||||
|
@ -405,8 +405,30 @@ impl completion::CompletionModel for CompletionModel {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct StreamingCompletionResponse {
|
||||
#[serde(default)]
|
||||
pub done_reason: Option<String>,
|
||||
#[serde(default)]
|
||||
pub total_duration: Option<u64>,
|
||||
#[serde(default)]
|
||||
pub load_duration: Option<u64>,
|
||||
#[serde(default)]
|
||||
pub prompt_eval_count: Option<u64>,
|
||||
#[serde(default)]
|
||||
pub prompt_eval_duration: Option<u64>,
|
||||
#[serde(default)]
|
||||
pub eval_count: Option<u64>,
|
||||
#[serde(default)]
|
||||
pub eval_duration: Option<u64>,
|
||||
}
|
||||
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
|
||||
type Response = StreamingCompletionResponse;
|
||||
|
||||
async fn stream(
|
||||
&self,
|
||||
request: CompletionRequest,
|
||||
) -> Result<streaming::StreamingCompletionResponse<Self::Response>, CompletionError> {
|
||||
let mut request_payload = self.create_completion_request(request)?;
|
||||
merge_inplace(&mut request_payload, json!({"stream": true}));
|
||||
|
||||
|
@ -426,7 +448,8 @@ impl StreamingCompletionModel for CompletionModel {
|
|||
return Err(CompletionError::ProviderError(err_text));
|
||||
}
|
||||
|
||||
Ok(Box::pin(stream! {
|
||||
let mut
|
||||
let stream = Box::pin(stream! {
|
||||
let mut stream = response.bytes_stream();
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = match chunk_result {
|
||||
|
@ -456,13 +479,13 @@ impl StreamingCompletionModel for CompletionModel {
|
|||
match response.message {
|
||||
Message::Assistant{ content, tool_calls, .. } => {
|
||||
if !content.is_empty() {
|
||||
yield Ok(StreamingChoice::Message(content))
|
||||
yield Ok(RawStreamingChoice::Message(content))
|
||||
}
|
||||
|
||||
for tool_call in tool_calls.iter() {
|
||||
let function = tool_call.function.clone();
|
||||
|
||||
yield Ok(StreamingChoice::ToolCall(function.name, "".to_string(), function.arguments));
|
||||
yield Ok(RawStreamingChoice::ToolCall(function.name, "".to_string(), function.arguments));
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
|
@ -471,7 +494,7 @@ impl StreamingCompletionModel for CompletionModel {
|
|||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ use super::completion::CompletionModel;
|
|||
use crate::completion::{CompletionError, CompletionRequest};
|
||||
use crate::json_utils;
|
||||
use crate::json_utils::merge;
|
||||
use crate::providers::openai::Usage;
|
||||
use crate::streaming;
|
||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
||||
use async_stream::stream;
|
||||
|
@ -10,6 +11,7 @@ use reqwest::RequestBuilder;
|
|||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use tokio::stream;
|
||||
|
||||
// ================================================================
|
||||
// OpenAI Completion Streaming API
|
||||
|
@ -42,15 +44,21 @@ struct StreamingChoice {
|
|||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct StreamingCompletionResponse {
|
||||
struct StreamingCompletionChunk {
|
||||
choices: Vec<StreamingChoice>,
|
||||
usage: Option<Usage>,
|
||||
}
|
||||
|
||||
pub struct StreamingCompletionResponse {
|
||||
usage: Option<Usage>,
|
||||
}
|
||||
|
||||
impl StreamingCompletionModel for CompletionModel {
|
||||
type Response = StreamingCompletionResponse;
|
||||
async fn stream(
|
||||
&self,
|
||||
completion_request: CompletionRequest,
|
||||
) -> Result<StreamingResult, CompletionError> {
|
||||
) -> Result<streaming::StreamingCompletionResponse<Self::Response>, CompletionError> {
|
||||
let mut request = self.create_completion_request(completion_request)?;
|
||||
request = merge(request, json!({"stream": true}));
|
||||
|
||||
|
@ -61,7 +69,7 @@ impl StreamingCompletionModel for CompletionModel {
|
|||
|
||||
pub async fn send_compatible_streaming_request(
|
||||
request_builder: RequestBuilder,
|
||||
) -> Result<StreamingResult, CompletionError> {
|
||||
) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError> {
|
||||
let response = request_builder.send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
|
@ -73,7 +81,7 @@ pub async fn send_compatible_streaming_request(
|
|||
}
|
||||
|
||||
// Handle OpenAI Compatible SSE chunks
|
||||
Ok(Box::pin(stream! {
|
||||
let inner = Box::pin(stream! {
|
||||
let mut stream = response.bytes_stream();
|
||||
|
||||
let mut partial_data = None;
|
||||
|
@ -121,7 +129,7 @@ pub async fn send_compatible_streaming_request(
|
|||
}
|
||||
}
|
||||
|
||||
let data = serde_json::from_str::<StreamingCompletionResponse>(&line);
|
||||
let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
|
||||
|
||||
let Ok(data) = data else {
|
||||
continue;
|
||||
|
@ -162,13 +170,17 @@ pub async fn send_compatible_streaming_request(
|
|||
continue;
|
||||
};
|
||||
|
||||
yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
|
||||
yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(content) = &choice.delta.content {
|
||||
yield Ok(streaming::StreamingChoice::Message(content.clone()))
|
||||
yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
|
||||
}
|
||||
|
||||
if &data.usage.is_some() {
|
||||
usage = data.usage;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -178,7 +190,11 @@ pub async fn send_compatible_streaming_request(
|
|||
continue;
|
||||
};
|
||||
|
||||
yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
|
||||
yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments))
|
||||
}
|
||||
}))
|
||||
});
|
||||
|
||||
Ok(streaming::StreamingCompletionResponse::new(
|
||||
inner,
|
||||
))
|
||||
}
|
||||
|
|
|
@ -13,14 +13,28 @@ use crate::agent::Agent;
|
|||
use crate::completion::{
|
||||
CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, Message,
|
||||
};
|
||||
use crate::message::AssistantContent;
|
||||
use futures::{Stream, StreamExt};
|
||||
use std::boxed::Box;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
/// Enum representing a streaming chunk from the model
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum RawStreamingChoice<R> {
|
||||
/// A text chunk from a message response
|
||||
Message(String),
|
||||
|
||||
/// A tool call response chunk
|
||||
ToolCall(String, String, serde_json::Value),
|
||||
|
||||
/// The final response object
|
||||
FinalResponse(R),
|
||||
}
|
||||
|
||||
/// Enum representing a streaming chunk from the model
|
||||
#[derive(Debug)]
|
||||
pub enum StreamingChoice {
|
||||
/// A text chunk from a message response
|
||||
Message(String),
|
||||
|
@ -41,29 +55,87 @@ impl Display for StreamingChoice {
|
|||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub type StreamingResult =
|
||||
Pin<Box<dyn Stream<Item = Result<StreamingChoice, CompletionError>> + Send>>;
|
||||
pub type StreamingResult<R> =
|
||||
Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
pub type StreamingResult = Pin<Box<dyn Stream<Item = Result<StreamingChoice, CompletionError>>>>;
|
||||
pub type StreamingResult = Pin<Box<dyn Stream<Item = Result<RawStreamingChoice, CompletionError>>>>;
|
||||
|
||||
pub struct StreamingCompletionResponse<R> {
|
||||
inner: StreamingResult<R>,
|
||||
text: String,
|
||||
tool_calls: Vec<(String, String, serde_json::Value)>,
|
||||
pub message: Message,
|
||||
pub response: Option<R>,
|
||||
}
|
||||
|
||||
impl<R> StreamingCompletionResponse<R> {
|
||||
pub fn new(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
|
||||
Self {
|
||||
inner,
|
||||
text: "".to_string(),
|
||||
tool_calls: vec![],
|
||||
message: Message::assistant(""),
|
||||
response: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> Stream for StreamingCompletionResponse<R> {
|
||||
type Item = Result<StreamingChoice, CompletionError>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
match self.inner.poll_next(cx) {
|
||||
Poll::Pending => Poll::Pending,
|
||||
|
||||
Poll::Ready(None) => {
|
||||
let content = vec![AssistantContent::text(self.text.clone())];
|
||||
|
||||
self.tool_calls
|
||||
.iter()
|
||||
.for_each(|(n, d, a)| AssistantContent::tool_call(n, derive!(), a));
|
||||
|
||||
self.message = Message::Assistant {
|
||||
content: content.into(),
|
||||
}
|
||||
}
|
||||
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
|
||||
Poll::Ready(Some(Ok(choice))) => match choice {
|
||||
RawStreamingChoice::Message(text) => {
|
||||
self.text = format!("{}{}", self.text, text);
|
||||
Poll::Ready(Some(Ok(choice.clone())))
|
||||
}
|
||||
RawStreamingChoice::ToolCall(name, description, args) => {
|
||||
self.tool_calls
|
||||
.push((name, description, args));
|
||||
Poll::Ready(Some(Ok(choice.clone())))
|
||||
}
|
||||
RawStreamingChoice::FinalResponse(response) => {
|
||||
self.response = Some(response);
|
||||
Poll::Pending
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for high-level streaming prompt interface
|
||||
pub trait StreamingPrompt: Send + Sync {
|
||||
pub trait StreamingPrompt<R>: Send + Sync {
|
||||
/// Stream a simple prompt to the model
|
||||
fn stream_prompt(
|
||||
&self,
|
||||
prompt: &str,
|
||||
) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
|
||||
) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
|
||||
}
|
||||
|
||||
/// Trait for high-level streaming chat interface
|
||||
pub trait StreamingChat: Send + Sync {
|
||||
pub trait StreamingChat<R>: Send + Sync {
|
||||
/// Stream a chat with history to the model
|
||||
fn stream_chat(
|
||||
&self,
|
||||
prompt: &str,
|
||||
chat_history: Vec<Message>,
|
||||
) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
|
||||
) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
|
||||
}
|
||||
|
||||
/// Trait for low-level streaming completion interface
|
||||
|
@ -78,17 +150,18 @@ pub trait StreamingCompletion<M: StreamingCompletionModel> {
|
|||
|
||||
/// Trait defining a streaming completion model
|
||||
pub trait StreamingCompletionModel: CompletionModel {
|
||||
type Response;
|
||||
/// Stream a completion response for the given request
|
||||
fn stream(
|
||||
&self,
|
||||
request: CompletionRequest,
|
||||
) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
|
||||
) -> impl Future<Output = Result<StreamingCompletionResponse<Self::Response>, CompletionError>>;
|
||||
}
|
||||
|
||||
/// helper function to stream a completion request to stdout
|
||||
pub async fn stream_to_stdout<M: StreamingCompletionModel>(
|
||||
pub async fn stream_to_stdout<M: StreamingCompletionModel, R>(
|
||||
agent: Agent<M>,
|
||||
stream: &mut StreamingResult,
|
||||
stream: &mut StreamingCompletionResponse<R>,
|
||||
) -> Result<(), std::io::Error> {
|
||||
print!("Response: ");
|
||||
while let Some(chunk) = stream.next().await {
|
||||
|
@ -111,6 +184,7 @@ pub async fn stream_to_stdout<M: StreamingCompletionModel>(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!(); // New line after streaming completes
|
||||
|
||||
Ok(())
|
||||
|
|
Loading…
Reference in New Issue