mirror of https://github.com/0xplaygrounds/rig
fix: compiles + formatted
This commit is contained in:
parent
fcbe648f77
commit
86b84c82fb
|
@ -16,6 +16,10 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
stream_to_stdout(agent, &mut stream).await?;
|
stream_to_stdout(agent, &mut stream).await?;
|
||||||
|
|
||||||
|
if let Some(response) = stream.response {
|
||||||
|
println!("Usage: {:?}", response.usage)
|
||||||
|
};
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -110,23 +110,20 @@ use std::collections::HashMap;
|
||||||
|
|
||||||
use futures::{stream, StreamExt, TryStreamExt};
|
use futures::{stream, StreamExt, TryStreamExt};
|
||||||
|
|
||||||
|
use crate::streaming::StreamingCompletionResponse;
|
||||||
|
#[cfg(feature = "mcp")]
|
||||||
|
use crate::tool::McpTool;
|
||||||
use crate::{
|
use crate::{
|
||||||
completion::{
|
completion::{
|
||||||
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
|
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
|
||||||
Message, Prompt, PromptError,
|
Message, Prompt, PromptError,
|
||||||
},
|
},
|
||||||
message::AssistantContent,
|
message::AssistantContent,
|
||||||
streaming::{
|
streaming::{StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt},
|
||||||
StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
|
|
||||||
StreamingResult,
|
|
||||||
},
|
|
||||||
tool::{Tool, ToolSet},
|
tool::{Tool, ToolSet},
|
||||||
vector_store::{VectorStoreError, VectorStoreIndexDyn},
|
vector_store::{VectorStoreError, VectorStoreIndexDyn},
|
||||||
};
|
};
|
||||||
|
|
||||||
#[cfg(feature = "mcp")]
|
|
||||||
use crate::tool::McpTool;
|
|
||||||
|
|
||||||
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
|
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
|
||||||
/// (i.e.: system prompt) and a static set of context documents and tools.
|
/// (i.e.: system prompt) and a static set of context documents and tools.
|
||||||
/// All context documents and tools are always provided to the agent when prompted.
|
/// All context documents and tools are always provided to the agent when prompted.
|
||||||
|
@ -500,18 +497,21 @@ impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
|
impl<M: StreamingCompletionModel> StreamingPrompt<M::StreamingResponse> for Agent<M> {
|
||||||
async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
|
async fn stream_prompt(
|
||||||
|
&self,
|
||||||
|
prompt: &str,
|
||||||
|
) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
|
||||||
self.stream_chat(prompt, vec![]).await
|
self.stream_chat(prompt, vec![]).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
|
impl<M: StreamingCompletionModel> StreamingChat<M::StreamingResponse> for Agent<M> {
|
||||||
async fn stream_chat(
|
async fn stream_chat(
|
||||||
&self,
|
&self,
|
||||||
prompt: &str,
|
prompt: &str,
|
||||||
chat_history: Vec<Message>,
|
chat_history: Vec<Message>,
|
||||||
) -> Result<StreamingResult, CompletionError> {
|
) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
|
||||||
self.stream_completion(prompt, chat_history)
|
self.stream_completion(prompt, chat_history)
|
||||||
.await?
|
.await?
|
||||||
.stream()
|
.stream()
|
||||||
|
|
|
@ -67,7 +67,7 @@ use std::collections::HashMap;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||||
use crate::OneOrMany;
|
use crate::OneOrMany;
|
||||||
use crate::{
|
use crate::{
|
||||||
json_utils,
|
json_utils,
|
||||||
|
@ -467,7 +467,9 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
|
||||||
|
|
||||||
impl<M: StreamingCompletionModel> CompletionRequestBuilder<M> {
|
impl<M: StreamingCompletionModel> CompletionRequestBuilder<M> {
|
||||||
/// Stream the completion request
|
/// Stream the completion request
|
||||||
pub async fn stream(self) -> Result<StreamingResult, CompletionError> {
|
pub async fn stream(
|
||||||
|
self,
|
||||||
|
) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
|
||||||
let model = self.model.clone();
|
let model = self.model.clone();
|
||||||
model.stream(self.build()).await
|
model.stream(self.build()).await
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ use super::decoders::sse::from_response as sse_from_response;
|
||||||
use crate::completion::{CompletionError, CompletionRequest};
|
use crate::completion::{CompletionError, CompletionRequest};
|
||||||
use crate::json_utils::merge_inplace;
|
use crate::json_utils::merge_inplace;
|
||||||
use crate::message::MessageError;
|
use crate::message::MessageError;
|
||||||
|
use crate::streaming;
|
||||||
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult};
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
|
@ -61,7 +62,7 @@ pub struct MessageDelta {
|
||||||
pub stop_sequence: Option<String>,
|
pub stop_sequence: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize, Clone)]
|
||||||
pub struct PartialUsage {
|
pub struct PartialUsage {
|
||||||
pub output_tokens: usize,
|
pub output_tokens: usize,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
@ -75,11 +76,18 @@ struct ToolCallState {
|
||||||
input_json: String,
|
input_json: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct StreamingCompletionResponse {
|
||||||
|
pub usage: PartialUsage,
|
||||||
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
|
type StreamingResponse = StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
completion_request: CompletionRequest,
|
completion_request: CompletionRequest,
|
||||||
) -> Result<StreamingResult, CompletionError> {
|
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
|
||||||
|
{
|
||||||
let max_tokens = if let Some(tokens) = completion_request.max_tokens {
|
let max_tokens = if let Some(tokens) = completion_request.max_tokens {
|
||||||
tokens
|
tokens
|
||||||
} else if let Some(tokens) = self.default_max_tokens {
|
} else if let Some(tokens) = self.default_max_tokens {
|
||||||
|
@ -155,9 +163,10 @@ impl StreamingCompletionModel for CompletionModel {
|
||||||
// Use our SSE decoder to directly handle Server-Sent Events format
|
// Use our SSE decoder to directly handle Server-Sent Events format
|
||||||
let sse_stream = sse_from_response(response);
|
let sse_stream = sse_from_response(response);
|
||||||
|
|
||||||
Ok(Box::pin(stream! {
|
let stream: StreamingResult<Self::StreamingResponse> = Box::pin(stream! {
|
||||||
let mut current_tool_call: Option<ToolCallState> = None;
|
let mut current_tool_call: Option<ToolCallState> = None;
|
||||||
let mut sse_stream = Box::pin(sse_stream);
|
let mut sse_stream = Box::pin(sse_stream);
|
||||||
|
let mut input_tokens = 0;
|
||||||
|
|
||||||
while let Some(sse_result) = sse_stream.next().await {
|
while let Some(sse_result) = sse_stream.next().await {
|
||||||
match sse_result {
|
match sse_result {
|
||||||
|
@ -165,6 +174,24 @@ impl StreamingCompletionModel for CompletionModel {
|
||||||
// Parse the SSE data as a StreamingEvent
|
// Parse the SSE data as a StreamingEvent
|
||||||
match serde_json::from_str::<StreamingEvent>(&sse.data) {
|
match serde_json::from_str::<StreamingEvent>(&sse.data) {
|
||||||
Ok(event) => {
|
Ok(event) => {
|
||||||
|
match &event {
|
||||||
|
StreamingEvent::MessageStart { message } => {
|
||||||
|
input_tokens = message.usage.input_tokens;
|
||||||
|
},
|
||||||
|
StreamingEvent::MessageDelta { delta, usage } => {
|
||||||
|
if delta.stop_reason.is_some() {
|
||||||
|
|
||||||
|
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
|
||||||
|
usage: PartialUsage {
|
||||||
|
output_tokens: usage.output_tokens,
|
||||||
|
input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
if let Some(result) = handle_event(&event, &mut current_tool_call) {
|
if let Some(result) = handle_event(&event, &mut current_tool_call) {
|
||||||
yield result;
|
yield result;
|
||||||
}
|
}
|
||||||
|
@ -184,14 +211,16 @@ impl StreamingCompletionModel for CompletionModel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}))
|
});
|
||||||
|
|
||||||
|
Ok(streaming::StreamingCompletionResponse::new(stream))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_event(
|
fn handle_event(
|
||||||
event: &StreamingEvent,
|
event: &StreamingEvent,
|
||||||
current_tool_call: &mut Option<ToolCallState>,
|
current_tool_call: &mut Option<ToolCallState>,
|
||||||
) -> Option<Result<RawStreamingChoice, CompletionError>> {
|
) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
|
||||||
match event {
|
match event {
|
||||||
StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
|
StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
|
||||||
ContentDelta::TextDelta { text } => {
|
ContentDelta::TextDelta { text } => {
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
use super::openai::{send_compatible_streaming_request, TranscriptionResponse};
|
use super::openai::{send_compatible_streaming_request, TranscriptionResponse};
|
||||||
|
|
||||||
use crate::json_utils::merge;
|
use crate::json_utils::merge;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||||
use crate::{
|
use crate::{
|
||||||
agent::AgentBuilder,
|
agent::AgentBuilder,
|
||||||
completion::{self, CompletionError, CompletionRequest},
|
completion::{self, CompletionError, CompletionRequest},
|
||||||
|
@ -570,11 +570,11 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
// Azure OpenAI Streaming API
|
// Azure OpenAI Streaming API
|
||||||
// -----------------------------------------------------
|
// -----------------------------------------------------
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
type Response = openai::StreamingCompletionResponse;
|
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
request: CompletionRequest,
|
request: CompletionRequest,
|
||||||
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
|
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(request)?;
|
let mut request = self.create_completion_request(request)?;
|
||||||
|
|
||||||
request = merge(
|
request = merge(
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
use crate::json_utils::merge;
|
use crate::json_utils::merge;
|
||||||
use crate::providers::openai;
|
use crate::providers::openai;
|
||||||
use crate::providers::openai::send_compatible_streaming_request;
|
use crate::providers::openai::send_compatible_streaming_request;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||||
use crate::{
|
use crate::{
|
||||||
completion::{self, CompletionError, CompletionModel, CompletionRequest},
|
completion::{self, CompletionError, CompletionModel, CompletionRequest},
|
||||||
extractor::ExtractorBuilder,
|
extractor::ExtractorBuilder,
|
||||||
|
@ -464,11 +464,11 @@ impl CompletionModel for DeepSeekCompletionModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for DeepSeekCompletionModel {
|
impl StreamingCompletionModel for DeepSeekCompletionModel {
|
||||||
type Response = openai::StreamingCompletionResponse;
|
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
completion_request: CompletionRequest,
|
completion_request: CompletionRequest,
|
||||||
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
|
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(completion_request)?;
|
let mut request = self.create_completion_request(completion_request)?;
|
||||||
|
|
||||||
request = merge(
|
request = merge(
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
use super::openai;
|
use super::openai;
|
||||||
use crate::json_utils::merge;
|
use crate::json_utils::merge;
|
||||||
use crate::providers::openai::send_compatible_streaming_request;
|
use crate::providers::openai::send_compatible_streaming_request;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||||
use crate::{
|
use crate::{
|
||||||
agent::AgentBuilder,
|
agent::AgentBuilder,
|
||||||
completion::{self, CompletionError, CompletionRequest},
|
completion::{self, CompletionError, CompletionRequest},
|
||||||
|
@ -495,12 +495,12 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
type Response = openai::StreamingCompletionResponse;
|
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||||
|
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
request: CompletionRequest,
|
request: CompletionRequest,
|
||||||
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
|
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(request)?;
|
let mut request = self.create_completion_request(request)?;
|
||||||
|
|
||||||
request = merge(
|
request = merge(
|
||||||
|
|
|
@ -609,7 +609,7 @@ pub mod gemini_api_types {
|
||||||
HarmCategoryCivicIntegrity,
|
HarmCategoryCivicIntegrity,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize, Clone)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct UsageMetadata {
|
pub struct UsageMetadata {
|
||||||
pub prompt_token_count: i32,
|
pub prompt_token_count: i32,
|
||||||
|
|
|
@ -2,26 +2,34 @@ use async_stream::stream;
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
use super::completion::{create_request_body, gemini_api_types::ContentCandidate, CompletionModel};
|
||||||
|
use crate::providers::gemini::completion::gemini_api_types::UsageMetadata;
|
||||||
use crate::{
|
use crate::{
|
||||||
completion::{CompletionError, CompletionRequest},
|
completion::{CompletionError, CompletionRequest},
|
||||||
streaming::{self, StreamingCompletionModel, StreamingResult},
|
streaming::{self, StreamingCompletionModel},
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::completion::{create_request_body, gemini_api_types::ContentCandidate, CompletionModel};
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct StreamGenerateContentResponse {
|
pub struct StreamGenerateContentResponse {
|
||||||
/// Candidate responses from the model.
|
/// Candidate responses from the model.
|
||||||
pub candidates: Vec<ContentCandidate>,
|
pub candidates: Vec<ContentCandidate>,
|
||||||
pub model_version: Option<String>,
|
pub model_version: Option<String>,
|
||||||
|
pub usage_metadata: UsageMetadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct StreamingCompletionResponse {
|
||||||
|
pub usage_metadata: UsageMetadata,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
|
type StreamingResponse = StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
completion_request: CompletionRequest,
|
completion_request: CompletionRequest,
|
||||||
) -> Result<StreamingResult, CompletionError> {
|
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
|
||||||
|
{
|
||||||
let request = create_request_body(completion_request)?;
|
let request = create_request_body(completion_request)?;
|
||||||
|
|
||||||
let response = self
|
let response = self
|
||||||
|
@ -42,7 +50,7 @@ impl StreamingCompletionModel for CompletionModel {
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Box::pin(stream! {
|
let stream = Box::pin(stream! {
|
||||||
let mut stream = response.bytes_stream();
|
let mut stream = response.bytes_stream();
|
||||||
|
|
||||||
while let Some(chunk_result) = stream.next().await {
|
while let Some(chunk_result) = stream.next().await {
|
||||||
|
@ -79,8 +87,16 @@ impl StreamingCompletionModel for CompletionModel {
|
||||||
=> yield Ok(streaming::RawStreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
|
=> yield Ok(streaming::RawStreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
|
||||||
_ => panic!("Unsupported response type with streaming.")
|
_ => panic!("Unsupported response type with streaming.")
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if choice.finish_reason.is_some() {
|
||||||
|
yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
|
||||||
|
usage_metadata: data.usage_metadata,
|
||||||
|
}))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}))
|
});
|
||||||
|
|
||||||
|
Ok(streaming::StreamingCompletionResponse::new(stream))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,7 +11,7 @@
|
||||||
use super::openai::{send_compatible_streaming_request, CompletionResponse, TranscriptionResponse};
|
use super::openai::{send_compatible_streaming_request, CompletionResponse, TranscriptionResponse};
|
||||||
use crate::json_utils::merge;
|
use crate::json_utils::merge;
|
||||||
use crate::providers::openai;
|
use crate::providers::openai;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||||
use crate::{
|
use crate::{
|
||||||
agent::AgentBuilder,
|
agent::AgentBuilder,
|
||||||
completion::{self, CompletionError, CompletionRequest},
|
completion::{self, CompletionError, CompletionRequest},
|
||||||
|
@ -364,11 +364,11 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
type Response = openai::StreamingCompletionResponse;
|
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
request: CompletionRequest,
|
request: CompletionRequest,
|
||||||
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
|
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(request)?;
|
let mut request = self.create_completion_request(request)?;
|
||||||
|
|
||||||
request = merge(
|
request = merge(
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
use super::completion::CompletionModel;
|
use super::completion::CompletionModel;
|
||||||
use crate::completion::{CompletionError, CompletionRequest};
|
use crate::completion::{CompletionError, CompletionRequest};
|
||||||
use crate::json_utils;
|
|
||||||
use crate::json_utils::merge_inplace;
|
use crate::json_utils::merge_inplace;
|
||||||
use crate::providers::openai::send_compatible_streaming_request;
|
use crate::providers::openai::{send_compatible_streaming_request, StreamingCompletionResponse};
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
use crate::streaming::StreamingCompletionModel;
|
||||||
|
use crate::{json_utils, streaming};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Value};
|
||||||
use std::convert::Infallible;
|
use std::convert::Infallible;
|
||||||
|
@ -55,14 +55,19 @@ struct CompletionChunk {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
|
type StreamingResponse = StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
completion_request: CompletionRequest,
|
completion_request: CompletionRequest,
|
||||||
) -> Result<StreamingResult, CompletionError> {
|
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
|
||||||
|
{
|
||||||
let mut request = self.create_request_body(&completion_request)?;
|
let mut request = self.create_request_body(&completion_request)?;
|
||||||
|
|
||||||
// Enable streaming
|
// Enable streaming
|
||||||
merge_inplace(&mut request, json!({"stream": true}));
|
merge_inplace(
|
||||||
|
&mut request,
|
||||||
|
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||||
|
);
|
||||||
|
|
||||||
if let Some(ref params) = completion_request.additional_params {
|
if let Some(ref params) = completion_request.additional_params {
|
||||||
merge_inplace(&mut request, params.clone());
|
merge_inplace(&mut request, params.clone());
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
use super::openai::{send_compatible_streaming_request, AssistantContent};
|
use super::openai::{send_compatible_streaming_request, AssistantContent};
|
||||||
|
|
||||||
use crate::json_utils::merge_inplace;
|
use crate::json_utils::merge_inplace;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||||
use crate::{
|
use crate::{
|
||||||
agent::AgentBuilder,
|
agent::AgentBuilder,
|
||||||
completion::{self, CompletionError, CompletionRequest},
|
completion::{self, CompletionError, CompletionRequest},
|
||||||
|
@ -390,11 +390,11 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
type Response = openai::StreamingCompletionResponse;
|
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
completion_request: CompletionRequest,
|
completion_request: CompletionRequest,
|
||||||
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
|
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(completion_request)?;
|
let mut request = self.create_completion_request(completion_request)?;
|
||||||
|
|
||||||
merge_inplace(
|
merge_inplace(
|
||||||
|
|
|
@ -8,8 +8,9 @@
|
||||||
//!
|
//!
|
||||||
//! ```
|
//! ```
|
||||||
use crate::json_utils::merge;
|
use crate::json_utils::merge;
|
||||||
|
use crate::providers::openai;
|
||||||
use crate::providers::openai::send_compatible_streaming_request;
|
use crate::providers::openai::send_compatible_streaming_request;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||||
use crate::{
|
use crate::{
|
||||||
agent::AgentBuilder,
|
agent::AgentBuilder,
|
||||||
completion::{self, CompletionError, CompletionRequest},
|
completion::{self, CompletionError, CompletionRequest},
|
||||||
|
@ -24,7 +25,6 @@ use serde_json::{json, Value};
|
||||||
use std::string::FromUtf8Error;
|
use std::string::FromUtf8Error;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tracing;
|
use tracing;
|
||||||
use crate::providers::openai;
|
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum MiraError {
|
pub enum MiraError {
|
||||||
|
@ -348,11 +348,11 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
type Response = openai::StreamingCompletionResponse;
|
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
completion_request: CompletionRequest,
|
completion_request: CompletionRequest,
|
||||||
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
|
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(completion_request)?;
|
let mut request = self.create_completion_request(completion_request)?;
|
||||||
|
|
||||||
request = merge(request, json!({"stream": true}));
|
request = merge(request, json!({"stream": true}));
|
||||||
|
|
|
@ -11,7 +11,7 @@
|
||||||
|
|
||||||
use crate::json_utils::merge;
|
use crate::json_utils::merge;
|
||||||
use crate::providers::openai::send_compatible_streaming_request;
|
use crate::providers::openai::send_compatible_streaming_request;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||||
use crate::{
|
use crate::{
|
||||||
agent::AgentBuilder,
|
agent::AgentBuilder,
|
||||||
completion::{self, CompletionError, CompletionRequest},
|
completion::{self, CompletionError, CompletionRequest},
|
||||||
|
@ -228,12 +228,12 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
type Response = openai::StreamingCompletionResponse;
|
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||||
|
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
request: CompletionRequest,
|
request: CompletionRequest,
|
||||||
) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
|
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(request)?;
|
let mut request = self.create_completion_request(request)?;
|
||||||
|
|
||||||
request = merge(
|
request = merge(
|
||||||
|
|
|
@ -39,7 +39,7 @@
|
||||||
//! let extractor = client.extractor::<serde_json::Value>("llama3.2");
|
//! let extractor = client.extractor::<serde_json::Value>("llama3.2");
|
||||||
//! ```
|
//! ```
|
||||||
use crate::json_utils::merge_inplace;
|
use crate::json_utils::merge_inplace;
|
||||||
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
|
||||||
use crate::{
|
use crate::{
|
||||||
agent::AgentBuilder,
|
agent::AgentBuilder,
|
||||||
completion::{self, CompletionError, CompletionRequest},
|
completion::{self, CompletionError, CompletionRequest},
|
||||||
|
@ -405,30 +405,25 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct StreamingCompletionResponse {
|
pub struct StreamingCompletionResponse {
|
||||||
#[serde(default)]
|
|
||||||
pub done_reason: Option<String>,
|
pub done_reason: Option<String>,
|
||||||
#[serde(default)]
|
|
||||||
pub total_duration: Option<u64>,
|
pub total_duration: Option<u64>,
|
||||||
#[serde(default)]
|
|
||||||
pub load_duration: Option<u64>,
|
pub load_duration: Option<u64>,
|
||||||
#[serde(default)]
|
|
||||||
pub prompt_eval_count: Option<u64>,
|
pub prompt_eval_count: Option<u64>,
|
||||||
#[serde(default)]
|
|
||||||
pub prompt_eval_duration: Option<u64>,
|
pub prompt_eval_duration: Option<u64>,
|
||||||
#[serde(default)]
|
|
||||||
pub eval_count: Option<u64>,
|
pub eval_count: Option<u64>,
|
||||||
#[serde(default)]
|
|
||||||
pub eval_duration: Option<u64>,
|
pub eval_duration: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
type Response = StreamingCompletionResponse;
|
type StreamingResponse = StreamingCompletionResponse;
|
||||||
|
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
request: CompletionRequest,
|
request: CompletionRequest,
|
||||||
) -> Result<streaming::StreamingCompletionResponse<Self::Response>, CompletionError> {
|
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
|
||||||
|
{
|
||||||
let mut request_payload = self.create_completion_request(request)?;
|
let mut request_payload = self.create_completion_request(request)?;
|
||||||
merge_inplace(&mut request_payload, json!({"stream": true}));
|
merge_inplace(&mut request_payload, json!({"stream": true}));
|
||||||
|
|
||||||
|
@ -448,7 +443,6 @@ impl StreamingCompletionModel for CompletionModel {
|
||||||
return Err(CompletionError::ProviderError(err_text));
|
return Err(CompletionError::ProviderError(err_text));
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut
|
|
||||||
let stream = Box::pin(stream! {
|
let stream = Box::pin(stream! {
|
||||||
let mut stream = response.bytes_stream();
|
let mut stream = response.bytes_stream();
|
||||||
while let Some(chunk_result) = stream.next().await {
|
while let Some(chunk_result) = stream.next().await {
|
||||||
|
@ -495,6 +489,8 @@ impl StreamingCompletionModel for CompletionModel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
Ok(streaming::StreamingCompletionResponse::new(stream))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4,14 +4,13 @@ use crate::json_utils;
|
||||||
use crate::json_utils::merge;
|
use crate::json_utils::merge;
|
||||||
use crate::providers::openai::Usage;
|
use crate::providers::openai::Usage;
|
||||||
use crate::streaming;
|
use crate::streaming;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
|
||||||
use async_stream::stream;
|
use async_stream::stream;
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use reqwest::RequestBuilder;
|
use reqwest::RequestBuilder;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use tokio::stream;
|
|
||||||
|
|
||||||
// ================================================================
|
// ================================================================
|
||||||
// OpenAI Completion Streaming API
|
// OpenAI Completion Streaming API
|
||||||
|
@ -47,18 +46,21 @@ struct StreamingChoice {
|
||||||
struct StreamingCompletionChunk {
|
struct StreamingCompletionChunk {
|
||||||
choices: Vec<StreamingChoice>,
|
choices: Vec<StreamingChoice>,
|
||||||
usage: Option<Usage>,
|
usage: Option<Usage>,
|
||||||
|
finish_reason: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct StreamingCompletionResponse {
|
pub struct StreamingCompletionResponse {
|
||||||
usage: Option<Usage>,
|
pub usage: Option<Usage>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
type Response = StreamingCompletionResponse;
|
type StreamingResponse = StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
completion_request: CompletionRequest,
|
completion_request: CompletionRequest,
|
||||||
) -> Result<streaming::StreamingCompletionResponse<Self::Response>, CompletionError> {
|
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
|
||||||
|
{
|
||||||
let mut request = self.create_completion_request(completion_request)?;
|
let mut request = self.create_completion_request(completion_request)?;
|
||||||
request = merge(request, json!({"stream": true}));
|
request = merge(request, json!({"stream": true}));
|
||||||
|
|
||||||
|
@ -179,9 +181,12 @@ pub async fn send_compatible_streaming_request(
|
||||||
yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
|
yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
if &data.usage.is_some() {
|
if data.finish_reason.is_some() {
|
||||||
usage = data.usage;
|
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
|
||||||
|
usage: data.usage
|
||||||
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -194,7 +199,5 @@ pub async fn send_compatible_streaming_request(
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
Ok(streaming::StreamingCompletionResponse::new(
|
Ok(streaming::StreamingCompletionResponse::new(inner))
|
||||||
inner,
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,8 +18,9 @@ use crate::{
|
||||||
|
|
||||||
use crate::completion::CompletionRequest;
|
use crate::completion::CompletionRequest;
|
||||||
use crate::json_utils::merge;
|
use crate::json_utils::merge;
|
||||||
|
use crate::providers::openai;
|
||||||
use crate::providers::openai::send_compatible_streaming_request;
|
use crate::providers::openai::send_compatible_streaming_request;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Value};
|
||||||
|
@ -345,10 +346,11 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
|
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
completion_request: completion::CompletionRequest,
|
completion_request: completion::CompletionRequest,
|
||||||
) -> Result<StreamingResult, CompletionError> {
|
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(completion_request)?;
|
let mut request = self.create_completion_request(completion_request)?;
|
||||||
|
|
||||||
request = merge(request, json!({"stream": true}));
|
request = merge(request, json!({"stream": true}));
|
||||||
|
|
|
@ -1,18 +1,21 @@
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
|
||||||
use super::completion::CompletionModel;
|
use super::completion::CompletionModel;
|
||||||
|
use crate::providers::openai;
|
||||||
use crate::providers::openai::send_compatible_streaming_request;
|
use crate::providers::openai::send_compatible_streaming_request;
|
||||||
|
use crate::streaming::StreamingCompletionResponse;
|
||||||
use crate::{
|
use crate::{
|
||||||
completion::{CompletionError, CompletionRequest},
|
completion::{CompletionError, CompletionRequest},
|
||||||
json_utils::merge,
|
json_utils::merge,
|
||||||
streaming::{StreamingCompletionModel, StreamingResult},
|
streaming::StreamingCompletionModel,
|
||||||
};
|
};
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
|
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
completion_request: CompletionRequest,
|
completion_request: CompletionRequest,
|
||||||
) -> Result<StreamingResult, CompletionError> {
|
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(completion_request)?;
|
let mut request = self.create_completion_request(completion_request)?;
|
||||||
|
|
||||||
request = merge(request, json!({"stream_tokens": true}));
|
request = merge(request, json!({"stream_tokens": true}));
|
||||||
|
|
|
@ -1,15 +1,17 @@
|
||||||
use crate::completion::{CompletionError, CompletionRequest};
|
use crate::completion::{CompletionError, CompletionRequest};
|
||||||
use crate::json_utils::merge;
|
use crate::json_utils::merge;
|
||||||
|
use crate::providers::openai;
|
||||||
use crate::providers::openai::send_compatible_streaming_request;
|
use crate::providers::openai::send_compatible_streaming_request;
|
||||||
use crate::providers::xai::completion::CompletionModel;
|
use crate::providers::xai::completion::CompletionModel;
|
||||||
use crate::streaming::{StreamingCompletionModel, StreamingResult};
|
use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
|
type StreamingResponse = openai::StreamingCompletionResponse;
|
||||||
async fn stream(
|
async fn stream(
|
||||||
&self,
|
&self,
|
||||||
completion_request: CompletionRequest,
|
completion_request: CompletionRequest,
|
||||||
) -> Result<StreamingResult, CompletionError> {
|
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
|
||||||
let mut request = self.create_completion_request(completion_request)?;
|
let mut request = self.create_completion_request(completion_request)?;
|
||||||
|
|
||||||
request = merge(request, json!({"stream": true}));
|
request = merge(request, json!({"stream": true}));
|
||||||
|
|
|
@ -14,6 +14,7 @@ use crate::completion::{
|
||||||
CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, Message,
|
CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, Message,
|
||||||
};
|
};
|
||||||
use crate::message::AssistantContent;
|
use crate::message::AssistantContent;
|
||||||
|
use crate::OneOrMany;
|
||||||
use futures::{Stream, StreamExt};
|
use futures::{Stream, StreamExt};
|
||||||
use std::boxed::Box;
|
use std::boxed::Box;
|
||||||
use std::fmt::{Display, Formatter};
|
use std::fmt::{Display, Formatter};
|
||||||
|
@ -23,7 +24,7 @@ use std::task::{Context, Poll};
|
||||||
|
|
||||||
/// Enum representing a streaming chunk from the model
|
/// Enum representing a streaming chunk from the model
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum RawStreamingChoice<R> {
|
pub enum RawStreamingChoice<R: Clone> {
|
||||||
/// A text chunk from a message response
|
/// A text chunk from a message response
|
||||||
Message(String),
|
Message(String),
|
||||||
|
|
||||||
|
@ -35,6 +36,7 @@ pub enum RawStreamingChoice<R> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Enum representing a streaming chunk from the model
|
/// Enum representing a streaming chunk from the model
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub enum StreamingChoice {
|
pub enum StreamingChoice {
|
||||||
/// A text chunk from a message response
|
/// A text chunk from a message response
|
||||||
Message(String),
|
Message(String),
|
||||||
|
@ -61,7 +63,7 @@ pub type StreamingResult<R> =
|
||||||
#[cfg(target_arch = "wasm32")]
|
#[cfg(target_arch = "wasm32")]
|
||||||
pub type StreamingResult = Pin<Box<dyn Stream<Item = Result<RawStreamingChoice, CompletionError>>>>;
|
pub type StreamingResult = Pin<Box<dyn Stream<Item = Result<RawStreamingChoice, CompletionError>>>>;
|
||||||
|
|
||||||
pub struct StreamingCompletionResponse<R> {
|
pub struct StreamingCompletionResponse<R: Clone + Unpin> {
|
||||||
inner: StreamingResult<R>,
|
inner: StreamingResult<R>,
|
||||||
text: String,
|
text: String,
|
||||||
tool_calls: Vec<(String, String, serde_json::Value)>,
|
tool_calls: Vec<(String, String, serde_json::Value)>,
|
||||||
|
@ -69,7 +71,7 @@ pub struct StreamingCompletionResponse<R> {
|
||||||
pub response: Option<R>,
|
pub response: Option<R>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R> StreamingCompletionResponse<R> {
|
impl<R: Clone + Unpin> StreamingCompletionResponse<R> {
|
||||||
pub fn new(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
|
pub fn new(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
|
||||||
Self {
|
Self {
|
||||||
inner,
|
inner,
|
||||||
|
@ -81,37 +83,43 @@ impl<R> StreamingCompletionResponse<R> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R> Stream for StreamingCompletionResponse<R> {
|
impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
|
||||||
type Item = Result<StreamingChoice, CompletionError>;
|
type Item = Result<StreamingChoice, CompletionError>;
|
||||||
|
|
||||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
match self.inner.poll_next(cx) {
|
let stream = self.get_mut();
|
||||||
|
|
||||||
|
match stream.inner.as_mut().poll_next(cx) {
|
||||||
Poll::Pending => Poll::Pending,
|
Poll::Pending => Poll::Pending,
|
||||||
|
|
||||||
Poll::Ready(None) => {
|
Poll::Ready(None) => {
|
||||||
let content = vec![AssistantContent::text(self.text.clone())];
|
let content = vec![AssistantContent::text(stream.text.clone())];
|
||||||
|
|
||||||
self.tool_calls
|
stream.tool_calls.iter().for_each(|(n, d, a)| {
|
||||||
.iter()
|
AssistantContent::tool_call(n, d, a.clone());
|
||||||
.for_each(|(n, d, a)| AssistantContent::tool_call(n, derive!(), a));
|
});
|
||||||
|
|
||||||
self.message = Message::Assistant {
|
stream.message = Message::Assistant {
|
||||||
content: content.into(),
|
content: OneOrMany::many(content)
|
||||||
}
|
.expect("There should be at least one assistant message"),
|
||||||
|
};
|
||||||
|
|
||||||
|
Poll::Ready(None)
|
||||||
}
|
}
|
||||||
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
|
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
|
||||||
Poll::Ready(Some(Ok(choice))) => match choice {
|
Poll::Ready(Some(Ok(choice))) => match choice {
|
||||||
RawStreamingChoice::Message(text) => {
|
RawStreamingChoice::Message(text) => {
|
||||||
self.text = format!("{}{}", self.text, text);
|
stream.text = format!("{}{}", stream.text, text.clone());
|
||||||
Poll::Ready(Some(Ok(choice.clone())))
|
Poll::Ready(Some(Ok(StreamingChoice::Message(text))))
|
||||||
}
|
}
|
||||||
RawStreamingChoice::ToolCall(name, description, args) => {
|
RawStreamingChoice::ToolCall(name, description, args) => {
|
||||||
self.tool_calls
|
stream
|
||||||
.push((name, description, args));
|
.tool_calls
|
||||||
Poll::Ready(Some(Ok(choice.clone())))
|
.push((name.clone(), description.clone(), args.clone()));
|
||||||
|
Poll::Ready(Some(Ok(StreamingChoice::ToolCall(name, description, args))))
|
||||||
}
|
}
|
||||||
RawStreamingChoice::FinalResponse(response) => {
|
RawStreamingChoice::FinalResponse(response) => {
|
||||||
self.response = Some(response);
|
stream.response = Some(response);
|
||||||
Poll::Pending
|
Poll::Pending
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -120,7 +128,7 @@ impl<R> Stream for StreamingCompletionResponse<R> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trait for high-level streaming prompt interface
|
/// Trait for high-level streaming prompt interface
|
||||||
pub trait StreamingPrompt<R>: Send + Sync {
|
pub trait StreamingPrompt<R: Clone + Unpin>: Send + Sync {
|
||||||
/// Stream a simple prompt to the model
|
/// Stream a simple prompt to the model
|
||||||
fn stream_prompt(
|
fn stream_prompt(
|
||||||
&self,
|
&self,
|
||||||
|
@ -129,7 +137,7 @@ pub trait StreamingPrompt<R>: Send + Sync {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trait for high-level streaming chat interface
|
/// Trait for high-level streaming chat interface
|
||||||
pub trait StreamingChat<R>: Send + Sync {
|
pub trait StreamingChat<R: Clone + Unpin>: Send + Sync {
|
||||||
/// Stream a chat with history to the model
|
/// Stream a chat with history to the model
|
||||||
fn stream_chat(
|
fn stream_chat(
|
||||||
&self,
|
&self,
|
||||||
|
@ -150,18 +158,20 @@ pub trait StreamingCompletion<M: StreamingCompletionModel> {
|
||||||
|
|
||||||
/// Trait defining a streaming completion model
|
/// Trait defining a streaming completion model
|
||||||
pub trait StreamingCompletionModel: CompletionModel {
|
pub trait StreamingCompletionModel: CompletionModel {
|
||||||
type Response;
|
type StreamingResponse: Clone + Unpin;
|
||||||
/// Stream a completion response for the given request
|
/// Stream a completion response for the given request
|
||||||
fn stream(
|
fn stream(
|
||||||
&self,
|
&self,
|
||||||
request: CompletionRequest,
|
request: CompletionRequest,
|
||||||
) -> impl Future<Output = Result<StreamingCompletionResponse<Self::Response>, CompletionError>>;
|
) -> impl Future<
|
||||||
|
Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
|
||||||
|
>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// helper function to stream a completion request to stdout
|
/// helper function to stream a completion request to stdout
|
||||||
pub async fn stream_to_stdout<M: StreamingCompletionModel, R>(
|
pub async fn stream_to_stdout<M: StreamingCompletionModel>(
|
||||||
agent: Agent<M>,
|
agent: Agent<M>,
|
||||||
stream: &mut StreamingCompletionResponse<R>,
|
stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
|
||||||
) -> Result<(), std::io::Error> {
|
) -> Result<(), std::io::Error> {
|
||||||
print!("Response: ");
|
print!("Response: ");
|
||||||
while let Some(chunk) = stream.next().await {
|
while let Some(chunk) = stream.next().await {
|
||||||
|
|
Loading…
Reference in New Issue