mirror of https://github.com/0xplaygrounds/rig
207 lines
7.0 KiB
Rust
207 lines
7.0 KiB
Rust
//! This module provides functionality for working with streaming completion models.
|
|
//! It provides traits and types for generating streaming completion requests and
|
|
//! handling streaming completion responses.
|
|
//!
|
|
//! The main traits defined in this module are:
|
|
//! - [StreamingPrompt]: Defines a high-level streaming LLM one-shot prompt interface
|
|
//! - [StreamingChat]: Defines a high-level streaming LLM chat interface with history
|
|
//! - [StreamingCompletion]: Defines a low-level streaming LLM completion interface
|
|
//! - [StreamingCompletionModel]: Defines a streaming completion model interface
|
|
//!
|
|
|
|
use crate::agent::Agent;
|
|
use crate::completion::{
|
|
CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, Message,
|
|
};
|
|
use crate::message::AssistantContent;
|
|
use crate::OneOrMany;
|
|
use futures::{Stream, StreamExt};
|
|
use std::boxed::Box;
|
|
use std::fmt::{Display, Formatter};
|
|
use std::future::Future;
|
|
use std::pin::Pin;
|
|
use std::task::{Context, Poll};
|
|
|
|
/// Enum representing a streaming chunk from the model
|
|
#[derive(Debug, Clone)]
|
|
pub enum RawStreamingChoice<R: Clone> {
|
|
/// A text chunk from a message response
|
|
Message(String),
|
|
|
|
/// A tool call response chunk
|
|
ToolCall(String, String, serde_json::Value),
|
|
|
|
/// The final response object
|
|
FinalResponse(R),
|
|
}
|
|
|
|
/// Enum representing a streaming chunk from the model
|
|
#[derive(Debug, Clone)]
|
|
pub enum StreamingChoice {
|
|
/// A text chunk from a message response
|
|
Message(String),
|
|
|
|
/// A tool call response chunk
|
|
ToolCall(String, String, serde_json::Value),
|
|
}
|
|
|
|
impl Display for StreamingChoice {
|
|
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
|
match self {
|
|
StreamingChoice::Message(text) => write!(f, "{}", text),
|
|
StreamingChoice::ToolCall(name, id, params) => {
|
|
write!(f, "Tool call: {} {} {:?}", name, id, params)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(not(target_arch = "wasm32"))]
|
|
pub type StreamingResult<R> =
|
|
Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;
|
|
|
|
#[cfg(target_arch = "wasm32")]
|
|
pub type StreamingResult<R> =
|
|
Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
|
|
|
|
pub struct StreamingCompletionResponse<R: Clone + Unpin> {
|
|
inner: StreamingResult<R>,
|
|
text: String,
|
|
tool_calls: Vec<(String, String, serde_json::Value)>,
|
|
pub message: Message,
|
|
pub response: Option<R>,
|
|
}
|
|
|
|
impl<R: Clone + Unpin> StreamingCompletionResponse<R> {
|
|
pub fn new(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
|
|
Self {
|
|
inner,
|
|
text: "".to_string(),
|
|
tool_calls: vec![],
|
|
message: Message::assistant(""),
|
|
response: None,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
|
|
type Item = Result<StreamingChoice, CompletionError>;
|
|
|
|
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
|
let stream = self.get_mut();
|
|
|
|
match stream.inner.as_mut().poll_next(cx) {
|
|
Poll::Pending => Poll::Pending,
|
|
Poll::Ready(None) => {
|
|
let mut content = vec![];
|
|
|
|
stream.tool_calls.iter().for_each(|(n, d, a)| {
|
|
content.push(AssistantContent::tool_call(n, d, a.clone()));
|
|
});
|
|
|
|
if content.is_empty() || !stream.text.is_empty() {
|
|
content.insert(0, AssistantContent::text(stream.text.clone()));
|
|
}
|
|
|
|
stream.message = Message::Assistant {
|
|
content: OneOrMany::many(content)
|
|
.expect("There should be at least one assistant message"),
|
|
};
|
|
|
|
Poll::Ready(None)
|
|
}
|
|
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
|
|
Poll::Ready(Some(Ok(choice))) => match choice {
|
|
RawStreamingChoice::Message(text) => {
|
|
stream.text = format!("{}{}", stream.text, text.clone());
|
|
Poll::Ready(Some(Ok(StreamingChoice::Message(text))))
|
|
}
|
|
RawStreamingChoice::ToolCall(name, description, args) => {
|
|
stream
|
|
.tool_calls
|
|
.push((name.clone(), description.clone(), args.clone()));
|
|
Poll::Ready(Some(Ok(StreamingChoice::ToolCall(name, description, args))))
|
|
}
|
|
RawStreamingChoice::FinalResponse(response) => {
|
|
stream.response = Some(response);
|
|
|
|
stream.poll_next_unpin(cx)
|
|
}
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Trait for high-level streaming prompt interface
|
|
pub trait StreamingPrompt<R: Clone + Unpin>: Send + Sync {
|
|
/// Stream a simple prompt to the model
|
|
fn stream_prompt(
|
|
&self,
|
|
prompt: &str,
|
|
) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
|
|
}
|
|
|
|
/// Trait for high-level streaming chat interface
|
|
pub trait StreamingChat<R: Clone + Unpin>: Send + Sync {
|
|
/// Stream a chat with history to the model
|
|
fn stream_chat(
|
|
&self,
|
|
prompt: &str,
|
|
chat_history: Vec<Message>,
|
|
) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
|
|
}
|
|
|
|
/// Trait for low-level streaming completion interface
|
|
pub trait StreamingCompletion<M: StreamingCompletionModel> {
|
|
/// Generate a streaming completion from a request
|
|
fn stream_completion(
|
|
&self,
|
|
prompt: impl Into<Message> + Send,
|
|
chat_history: Vec<Message>,
|
|
) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
|
|
}
|
|
|
|
/// Trait defining a streaming completion model
|
|
pub trait StreamingCompletionModel: CompletionModel {
|
|
type StreamingResponse: Clone + Unpin;
|
|
/// Stream a completion response for the given request
|
|
fn stream(
|
|
&self,
|
|
request: CompletionRequest,
|
|
) -> impl Future<
|
|
Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
|
|
>;
|
|
}
|
|
|
|
/// helper function to stream a completion request to stdout
|
|
pub async fn stream_to_stdout<M: StreamingCompletionModel>(
|
|
agent: Agent<M>,
|
|
stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
|
|
) -> Result<(), std::io::Error> {
|
|
print!("Response: ");
|
|
while let Some(chunk) = stream.next().await {
|
|
match chunk {
|
|
Ok(StreamingChoice::Message(text)) => {
|
|
print!("{}", text);
|
|
std::io::Write::flush(&mut std::io::stdout())?;
|
|
}
|
|
Ok(StreamingChoice::ToolCall(name, _, params)) => {
|
|
let res = agent
|
|
.tools
|
|
.call(&name, params.to_string())
|
|
.await
|
|
.map_err(|e| std::io::Error::other(e.to_string()))?;
|
|
println!("\nResult: {}", res);
|
|
}
|
|
Err(e) => {
|
|
eprintln!("Error: {}", e);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
println!(); // New line after streaming completes
|
|
|
|
Ok(())
|
|
}
|