mirror of https://github.com/0xplaygrounds/rig

fix: compiles + formatted
commit 86b84c82fb (parent fcbe648f77)

@@ -16,6 +16,10 @@ async fn main() -> Result<(), anyhow::Error> {
         .await?;
 
     stream_to_stdout(agent, &mut stream).await?;
 
+    if let Some(response) = stream.response {
+        println!("Usage: {:?}", response.usage)
+    };
+
     Ok(())
 }
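The hunk above is the tail of a streaming example. A fuller sketch of how the example reads after this change (client setup, provider, and model name are assumed here; only the lines shown in the diff are from the commit):

    use rig::providers::anthropic;
    use rig::streaming::{stream_to_stdout, StreamingPrompt};

    #[tokio::main]
    async fn main() -> Result<(), anyhow::Error> {
        // Hypothetical setup; any provider whose model implements
        // StreamingCompletionModel works the same way.
        let agent = anthropic::Client::from_env()
            .agent("claude-3-7-sonnet-latest")
            .build();

        let mut stream = agent.stream_prompt("Hello!").await?;
        stream_to_stdout(agent, &mut stream).await?;

        // New in this commit: after the stream is drained, the handle
        // exposes the provider's typed final response with token usage.
        if let Some(response) = stream.response {
            println!("Usage: {:?}", response.usage)
        };

        Ok(())
    }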
@@ -110,23 +110,20 @@ use std::collections::HashMap;
 
 use futures::{stream, StreamExt, TryStreamExt};
 
+use crate::streaming::StreamingCompletionResponse;
+#[cfg(feature = "mcp")]
+use crate::tool::McpTool;
 use crate::{
     completion::{
         Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
         Message, Prompt, PromptError,
     },
     message::AssistantContent,
-    streaming::{
-        StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
-        StreamingResult,
-    },
+    streaming::{StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt},
     tool::{Tool, ToolSet},
     vector_store::{VectorStoreError, VectorStoreIndexDyn},
 };
 
-#[cfg(feature = "mcp")]
-use crate::tool::McpTool;
-
 /// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
 /// (i.e.: system prompt) and a static set of context documents and tools.
 /// All context documents and tools are always provided to the agent when prompted.

@@ -500,18 +497,21 @@ impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
     }
 }
 
-impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
-    async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
+impl<M: StreamingCompletionModel> StreamingPrompt<M::StreamingResponse> for Agent<M> {
+    async fn stream_prompt(
+        &self,
+        prompt: &str,
+    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
         self.stream_chat(prompt, vec![]).await
     }
 }
 
-impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
+impl<M: StreamingCompletionModel> StreamingChat<M::StreamingResponse> for Agent<M> {
     async fn stream_chat(
         &self,
         prompt: &str,
         chat_history: Vec<Message>,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
         self.stream_completion(prompt, chat_history)
             .await?
             .stream()
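With StreamingPrompt and StreamingChat now generic over the provider's final-response type, downstream code can stay generic over the model and still reach the typed response once the stream ends. A minimal sketch under the signatures in this diff (the StreamingChat trait and futures::StreamExt are assumed to be in scope):

    use futures::StreamExt;

    async fn drain<M: StreamingCompletionModel>(
        agent: &Agent<M>,
    ) -> Result<Option<M::StreamingResponse>, CompletionError> {
        let mut stream = agent.stream_chat("Hi", vec![]).await?;
        while let Some(chunk) = stream.next().await {
            let _ = chunk?; // text and tool-call chunks
        }
        // Populated by the FinalResponse chunk, if the provider sent one.
        Ok(stream.response)
    }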
@@ -67,7 +67,7 @@ use std::collections::HashMap;
 use serde::{Deserialize, Serialize};
 use thiserror::Error;
 
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::OneOrMany;
 use crate::{
     json_utils,

@@ -467,7 +467,9 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
 
 impl<M: StreamingCompletionModel> CompletionRequestBuilder<M> {
     /// Stream the completion request
-    pub async fn stream(self) -> Result<StreamingResult, CompletionError> {
+    pub async fn stream(
+        self,
+    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
         let model = self.model.clone();
         model.stream(self.build()).await
     }
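Taken together with the agent change above, the builder now terminates a fluent chain in a typed stream handle; this is exactly the chain the new Agent::stream_chat body uses (sketch; `agent` and the history argument are assumed):

    let mut stream = agent
        .stream_completion("prompt", vec![])
        .await?   // -> CompletionRequestBuilder<M>
        .stream() // -> StreamingCompletionResponse<M::StreamingResponse>
        .await?;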
@@ -8,6 +8,7 @@ use super::decoders::sse::from_response as sse_from_response;
 use crate::completion::{CompletionError, CompletionRequest};
 use crate::json_utils::merge_inplace;
 use crate::message::MessageError;
+use crate::streaming;
 use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult};
 
 #[derive(Debug, Deserialize)]

@@ -61,7 +62,7 @@ pub struct MessageDelta {
     pub stop_sequence: Option<String>,
 }
 
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Deserialize, Clone)]
 pub struct PartialUsage {
     pub output_tokens: usize,
     #[serde(default)]

@@ -75,11 +76,18 @@ struct ToolCallState {
     input_json: String,
 }
 
+#[derive(Clone)]
+pub struct StreamingCompletionResponse {
+    pub usage: PartialUsage,
+}
+
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
+    {
         let max_tokens = if let Some(tokens) = completion_request.max_tokens {
             tokens
         } else if let Some(tokens) = self.default_max_tokens {

@@ -155,9 +163,10 @@ impl StreamingCompletionModel for CompletionModel {
         // Use our SSE decoder to directly handle Server-Sent Events format
         let sse_stream = sse_from_response(response);
 
-        Ok(Box::pin(stream! {
+        let stream: StreamingResult<Self::StreamingResponse> = Box::pin(stream! {
             let mut current_tool_call: Option<ToolCallState> = None;
             let mut sse_stream = Box::pin(sse_stream);
+            let mut input_tokens = 0;
 
             while let Some(sse_result) = sse_stream.next().await {
                 match sse_result {

@@ -165,6 +174,24 @@ impl StreamingCompletionModel for CompletionModel {
                         // Parse the SSE data as a StreamingEvent
                         match serde_json::from_str::<StreamingEvent>(&sse.data) {
                             Ok(event) => {
+                                match &event {
+                                    StreamingEvent::MessageStart { message } => {
+                                        input_tokens = message.usage.input_tokens;
+                                    },
+                                    StreamingEvent::MessageDelta { delta, usage } => {
+                                        if delta.stop_reason.is_some() {
+                                            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
+                                                usage: PartialUsage {
+                                                    output_tokens: usage.output_tokens,
+                                                    input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
+                                                }
+                                            }))
+                                        }
+                                    }
+                                    _ => {}
+                                }
+
                                 if let Some(result) = handle_event(&event, &mut current_tool_call) {
                                     yield result;
                                 }

@@ -184,14 +211,16 @@ impl StreamingCompletionModel for CompletionModel {
                     }
                 }
             }
-        }))
+        });
+
+        Ok(streaming::StreamingCompletionResponse::new(stream))
     }
 }
 
 fn handle_event(
     event: &StreamingEvent,
     current_tool_call: &mut Option<ToolCallState>,
-) -> Option<Result<RawStreamingChoice, CompletionError>> {
+) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
     match event {
         StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
             ContentDelta::TextDelta { text } => {
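The Anthropic stream now remembers input_tokens from the MessageStart event and, once a MessageDelta carries a stop reason, yields a single typed FinalResponse with the combined usage. Consumer-side sketch (field names per this diff; input_tokens is optional, as the serde(default) on PartialUsage suggests):

    if let Some(final_response) = stream.response {
        let usage = final_response.usage;
        println!(
            "input: {:?}, output: {}",
            usage.input_tokens, usage.output_tokens
        );
    }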
@@ -12,7 +12,7 @@
 use super::openai::{send_compatible_streaming_request, TranscriptionResponse};
 
 use crate::json_utils::merge;
-use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},

@@ -570,11 +570,11 @@ impl completion::CompletionModel for CompletionModel {
 // Azure OpenAI Streaming API
 // -----------------------------------------------------
 impl StreamingCompletionModel for CompletionModel {
-    type Response = openai::StreamingCompletionResponse;
+    type StreamingResponse = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         request: CompletionRequest,
-    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(request)?;
 
         request = merge(
@@ -12,7 +12,7 @@
 use crate::json_utils::merge;
 use crate::providers::openai;
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     completion::{self, CompletionError, CompletionModel, CompletionRequest},
     extractor::ExtractorBuilder,

@@ -464,11 +464,11 @@ impl CompletionModel for DeepSeekCompletionModel {
 }
 
 impl StreamingCompletionModel for DeepSeekCompletionModel {
-    type Response = openai::StreamingCompletionResponse;
+    type StreamingResponse = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;
 
         request = merge(
@@ -13,7 +13,7 @@
 use super::openai;
 use crate::json_utils::merge;
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},

@@ -495,12 +495,12 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
-    type Response = openai::StreamingCompletionResponse;
+    type StreamingResponse = openai::StreamingCompletionResponse;
 
     async fn stream(
         &self,
         request: CompletionRequest,
-    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(request)?;
 
         request = merge(
@@ -609,7 +609,7 @@ pub mod gemini_api_types {
     HarmCategoryCivicIntegrity,
 }
 
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Deserialize, Clone)]
 #[serde(rename_all = "camelCase")]
 pub struct UsageMetadata {
     pub prompt_token_count: i32,
@@ -2,26 +2,34 @@ use async_stream::stream;
 use futures::StreamExt;
 use serde::Deserialize;
 
-use super::completion::{create_request_body, gemini_api_types::ContentCandidate, CompletionModel};
+use crate::providers::gemini::completion::gemini_api_types::UsageMetadata;
 use crate::{
     completion::{CompletionError, CompletionRequest},
-    streaming::{self, StreamingCompletionModel, StreamingResult},
+    streaming::{self, StreamingCompletionModel},
 };
 
+use super::completion::{create_request_body, gemini_api_types::ContentCandidate, CompletionModel};
+
 #[derive(Debug, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct StreamGenerateContentResponse {
     /// Candidate responses from the model.
     pub candidates: Vec<ContentCandidate>,
     pub model_version: Option<String>,
+    pub usage_metadata: UsageMetadata,
 }
 
+#[derive(Clone)]
+pub struct StreamingCompletionResponse {
+    pub usage_metadata: UsageMetadata,
+}
+
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
+    {
         let request = create_request_body(completion_request)?;
 
         let response = self

@@ -42,7 +50,7 @@ impl StreamingCompletionModel for CompletionModel {
             )));
         }
 
-        Ok(Box::pin(stream! {
+        let stream = Box::pin(stream! {
             let mut stream = response.bytes_stream();
 
             while let Some(chunk_result) = stream.next().await {

@@ -79,8 +87,16 @@ impl StreamingCompletionModel for CompletionModel {
                             => yield Ok(streaming::RawStreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
                         _ => panic!("Unsupported response type with streaming.")
                     };
+
+                    if choice.finish_reason.is_some() {
+                        yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
+                            usage_metadata: data.usage_metadata,
+                        }))
+                    }
                 }
             }
-        }))
+        });
+
+        Ok(streaming::StreamingCompletionResponse::new(stream))
     }
 }
@@ -11,7 +11,7 @@
 use super::openai::{send_compatible_streaming_request, CompletionResponse, TranscriptionResponse};
 use crate::json_utils::merge;
 use crate::providers::openai;
-use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},

@@ -364,11 +364,11 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
-    type Response = openai::StreamingCompletionResponse;
+    type StreamingResponse = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         request: CompletionRequest,
-    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(request)?;
 
         request = merge(
@@ -1,9 +1,9 @@
 use super::completion::CompletionModel;
 use crate::completion::{CompletionError, CompletionRequest};
-use crate::json_utils;
 use crate::json_utils::merge_inplace;
-use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::providers::openai::{send_compatible_streaming_request, StreamingCompletionResponse};
+use crate::streaming::StreamingCompletionModel;
+use crate::{json_utils, streaming};
 use serde::{Deserialize, Serialize};
 use serde_json::{json, Value};
 use std::convert::Infallible;

@@ -55,14 +55,19 @@ struct CompletionChunk {
 }
 
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
+    {
         let mut request = self.create_request_body(&completion_request)?;
 
         // Enable streaming
-        merge_inplace(&mut request, json!({"stream": true}));
+        merge_inplace(
+            &mut request,
+            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        );
 
         if let Some(ref params) = completion_request.additional_params {
             merge_inplace(&mut request, params.clone());
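The extra stream_options field matters because OpenAI-compatible endpoints omit token usage from stream chunks unless explicitly asked for it. A sketch of the merged request body this produces (the rest of the body comes from the base request):

    let body = serde_json::json!({
        "stream": true,
        "stream_options": { "include_usage": true }
    });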
@@ -12,7 +12,7 @@
 use super::openai::{send_compatible_streaming_request, AssistantContent};
 
 use crate::json_utils::merge_inplace;
-use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},

@@ -390,11 +390,11 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
-    type Response = openai::StreamingCompletionResponse;
+    type StreamingResponse = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;
 
         merge_inplace(
@@ -8,8 +8,9 @@
 //!
 //! ```
 use crate::json_utils::merge;
+use crate::providers::openai;
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},

@@ -24,7 +25,6 @@ use serde_json::{json, Value};
 use std::string::FromUtf8Error;
 use thiserror::Error;
 use tracing;
-use crate::providers::openai;
 
 #[derive(Debug, Error)]
 pub enum MiraError {

@@ -348,11 +348,11 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
-    type Response = openai::StreamingCompletionResponse;
+    type StreamingResponse = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;
 
         request = merge(request, json!({"stream": true}));
@@ -11,7 +11,7 @@
 
 use crate::json_utils::merge;
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},

@@ -228,12 +228,12 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
-    type Response = openai::StreamingCompletionResponse;
+    type StreamingResponse = openai::StreamingCompletionResponse;
 
     async fn stream(
         &self,
         request: CompletionRequest,
-    ) -> Result<StreamingCompletionResponse<Self::Response>, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(request)?;
 
         request = merge(
@@ -39,7 +39,7 @@
 //! let extractor = client.extractor::<serde_json::Value>("llama3.2");
 //! ```
 use crate::json_utils::merge_inplace;
-use crate::streaming::{RawStreamingChoice, StreamingCompletionModel, StreamingResult};
+use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},

@@ -405,30 +405,25 @@ impl completion::CompletionModel for CompletionModel {
     }
 }
 
 #[derive(Clone)]
 pub struct StreamingCompletionResponse {
-    #[serde(default)]
     pub done_reason: Option<String>,
-    #[serde(default)]
     pub total_duration: Option<u64>,
-    #[serde(default)]
     pub load_duration: Option<u64>,
-    #[serde(default)]
     pub prompt_eval_count: Option<u64>,
-    #[serde(default)]
     pub prompt_eval_duration: Option<u64>,
-    #[serde(default)]
     pub eval_count: Option<u64>,
-    #[serde(default)]
     pub eval_duration: Option<u64>,
 }
 
 impl StreamingCompletionModel for CompletionModel {
-    type Response = StreamingCompletionResponse;
+    type StreamingResponse = StreamingCompletionResponse;
 
     async fn stream(
         &self,
         request: CompletionRequest,
-    ) -> Result<streaming::StreamingCompletionResponse<Self::Response>, CompletionError> {
+    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
+    {
         let mut request_payload = self.create_completion_request(request)?;
         merge_inplace(&mut request_payload, json!({"stream": true}));

@@ -448,7 +443,6 @@ impl StreamingCompletionModel for CompletionModel {
             return Err(CompletionError::ProviderError(err_text));
         }
 
-        let mut
         let stream = Box::pin(stream! {
             let mut stream = response.bytes_stream();
             while let Some(chunk_result) = stream.next().await {

@@ -495,6 +489,8 @@ impl StreamingCompletionModel for CompletionModel {
                 }
             }
         });
 
+        Ok(streaming::StreamingCompletionResponse::new(stream))
     }
 }
@@ -4,14 +4,13 @@ use crate::json_utils;
 use crate::json_utils::merge;
 use crate::providers::openai::Usage;
 use crate::streaming;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
 use async_stream::stream;
 use futures::StreamExt;
 use reqwest::RequestBuilder;
 use serde::{Deserialize, Serialize};
 use serde_json::json;
 use std::collections::HashMap;
-use tokio::stream;
 
 // ================================================================
 // OpenAI Completion Streaming API

@@ -47,18 +46,21 @@ struct StreamingChoice {
 struct StreamingCompletionChunk {
     choices: Vec<StreamingChoice>,
     usage: Option<Usage>,
+    finish_reason: Option<String>,
 }
 
+#[derive(Clone)]
 pub struct StreamingCompletionResponse {
-    usage: Option<Usage>,
+    pub usage: Option<Usage>,
 }
 
 impl StreamingCompletionModel for CompletionModel {
-    type Response = StreamingCompletionResponse;
+    type StreamingResponse = StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<streaming::StreamingCompletionResponse<Self::Response>, CompletionError> {
+    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
+    {
         let mut request = self.create_completion_request(completion_request)?;
         request = merge(request, json!({"stream": true}));
 

@@ -179,9 +181,12 @@ pub async fn send_compatible_streaming_request(
                         yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
                     }
 
-                    if &data.usage.is_some() {
-                        usage = data.usage;
+                    if data.finish_reason.is_some() {
+                        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
+                            usage: data.usage
+                        }))
                     }
 
                 }
             }

@@ -194,7 +199,5 @@ pub async fn send_compatible_streaming_request(
             }
         });
 
-    Ok(streaming::StreamingCompletionResponse::new(
-        inner,
-    ))
+    Ok(streaming::StreamingCompletionResponse::new(inner))
 }
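Besides switching from tracking usage to emitting it, the new condition also repairs a type error: the old `if &data.usage.is_some()` tested a `&bool`. The new logic, reduced to a standalone helper for clarity (names from this diff; the helper itself is illustrative, and cloning assumes Usage is Clone, which the Clone derive on the response struct already requires):

    fn final_choice(
        data: &StreamingCompletionChunk,
    ) -> Option<RawStreamingChoice<StreamingCompletionResponse>> {
        // Only chunks that carry a finish_reason end the logical response.
        data.finish_reason.as_ref()?;
        Some(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
            usage: data.usage.clone(),
        }))
    }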
@@ -18,8 +18,9 @@ use crate::{
 
 use crate::completion::CompletionRequest;
 use crate::json_utils::merge;
+use crate::providers::openai;
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use serde_json::{json, Value};

@@ -345,10 +346,11 @@ impl completion::CompletionModel for CompletionModel {
 }
 
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: completion::CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;
 
         request = merge(request, json!({"stream": true}));
@@ -1,18 +1,21 @@
 use serde_json::json;
 
 use super::completion::CompletionModel;
+use crate::providers::openai;
 use crate::providers::openai::send_compatible_streaming_request;
+use crate::streaming::StreamingCompletionResponse;
 use crate::{
     completion::{CompletionError, CompletionRequest},
     json_utils::merge,
-    streaming::{StreamingCompletionModel, StreamingResult},
+    streaming::StreamingCompletionModel,
 };
 
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;
 
         request = merge(request, json!({"stream_tokens": true}));
@@ -1,15 +1,17 @@
 use crate::completion::{CompletionError, CompletionRequest};
 use crate::json_utils::merge;
+use crate::providers::openai;
 use crate::providers::openai::send_compatible_streaming_request;
 use crate::providers::xai::completion::CompletionModel;
-use crate::streaming::{StreamingCompletionModel, StreamingResult};
+use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
 use serde_json::json;
 
 impl StreamingCompletionModel for CompletionModel {
+    type StreamingResponse = openai::StreamingCompletionResponse;
     async fn stream(
         &self,
         completion_request: CompletionRequest,
-    ) -> Result<StreamingResult, CompletionError> {
+    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
         let mut request = self.create_completion_request(completion_request)?;
 
         request = merge(request, json!({"stream": true}));
@@ -14,6 +14,7 @@ use crate::completion::{
     CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, Message,
 };
 use crate::message::AssistantContent;
+use crate::OneOrMany;
 use futures::{Stream, StreamExt};
 use std::boxed::Box;
 use std::fmt::{Display, Formatter};

@@ -23,7 +24,7 @@ use std::task::{Context, Poll};
 
 /// Enum representing a streaming chunk from the model
 #[derive(Debug, Clone)]
-pub enum RawStreamingChoice<R> {
+pub enum RawStreamingChoice<R: Clone> {
     /// A text chunk from a message response
     Message(String),
 

@@ -35,6 +36,7 @@ pub enum RawStreamingChoice<R> {
 }
 
+/// Enum representing a streaming chunk from the model
 #[derive(Debug, Clone)]
 pub enum StreamingChoice {
     /// A text chunk from a message response
     Message(String),

@@ -61,7 +63,7 @@ pub type StreamingResult<R> =
 #[cfg(target_arch = "wasm32")]
 pub type StreamingResult = Pin<Box<dyn Stream<Item = Result<RawStreamingChoice, CompletionError>>>>;
 
-pub struct StreamingCompletionResponse<R> {
+pub struct StreamingCompletionResponse<R: Clone + Unpin> {
     inner: StreamingResult<R>,
     text: String,
     tool_calls: Vec<(String, String, serde_json::Value)>,

@@ -69,7 +71,7 @@ pub struct StreamingCompletionResponse<R> {
     pub response: Option<R>,
 }
 
-impl<R> StreamingCompletionResponse<R> {
+impl<R: Clone + Unpin> StreamingCompletionResponse<R> {
     pub fn new(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
         Self {
             inner,
@@ -81,37 +83,43 @@ impl<R> StreamingCompletionResponse<R> {
     }
 }
 
-impl<R> Stream for StreamingCompletionResponse<R> {
+impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
     type Item = Result<StreamingChoice, CompletionError>;
 
     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        match self.inner.poll_next(cx) {
+        let stream = self.get_mut();
+
+        match stream.inner.as_mut().poll_next(cx) {
             Poll::Pending => Poll::Pending,
 
             Poll::Ready(None) => {
-                let content = vec![AssistantContent::text(self.text.clone())];
+                let content = vec![AssistantContent::text(stream.text.clone())];
 
-                self.tool_calls
-                    .iter()
-                    .for_each(|(n, d, a)| AssistantContent::tool_call(n, derive!(), a));
+                stream.tool_calls.iter().for_each(|(n, d, a)| {
+                    AssistantContent::tool_call(n, d, a.clone());
+                });
 
-                self.message = Message::Assistant {
-                    content: content.into(),
-                }
+                stream.message = Message::Assistant {
+                    content: OneOrMany::many(content)
+                        .expect("There should be at least one assistant message"),
+                };
 
                 Poll::Ready(None)
             }
             Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
             Poll::Ready(Some(Ok(choice))) => match choice {
                 RawStreamingChoice::Message(text) => {
-                    self.text = format!("{}{}", self.text, text);
-                    Poll::Ready(Some(Ok(choice.clone())))
+                    stream.text = format!("{}{}", stream.text, text.clone());
+                    Poll::Ready(Some(Ok(StreamingChoice::Message(text))))
                 }
                 RawStreamingChoice::ToolCall(name, description, args) => {
-                    self.tool_calls
-                        .push((name, description, args));
-                    Poll::Ready(Some(Ok(choice.clone())))
+                    stream
+                        .tool_calls
+                        .push((name.clone(), description.clone(), args.clone()));
+                    Poll::Ready(Some(Ok(StreamingChoice::ToolCall(name, description, args))))
                 }
                 RawStreamingChoice::FinalResponse(response) => {
-                    self.response = Some(response);
+                    stream.response = Some(response);
                     Poll::Pending
                 }
             },
@@ -120,7 +128,7 @@ impl<R> Stream for StreamingCompletionResponse<R> {
 }
 
 /// Trait for high-level streaming prompt interface
-pub trait StreamingPrompt<R>: Send + Sync {
+pub trait StreamingPrompt<R: Clone + Unpin>: Send + Sync {
     /// Stream a simple prompt to the model
     fn stream_prompt(
         &self,

@@ -129,7 +137,7 @@ pub trait StreamingPrompt<R>: Send + Sync {
 }
 
 /// Trait for high-level streaming chat interface
-pub trait StreamingChat<R>: Send + Sync {
+pub trait StreamingChat<R: Clone + Unpin>: Send + Sync {
     /// Stream a chat with history to the model
     fn stream_chat(
         &self,

@@ -150,18 +158,20 @@ pub trait StreamingCompletion<M: StreamingCompletionModel> {
 
 /// Trait defining a streaming completion model
 pub trait StreamingCompletionModel: CompletionModel {
-    type Response;
+    type StreamingResponse: Clone + Unpin;
     /// Stream a completion response for the given request
     fn stream(
         &self,
         request: CompletionRequest,
-    ) -> impl Future<Output = Result<StreamingCompletionResponse<Self::Response>, CompletionError>>;
+    ) -> impl Future<
+        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
+    >;
 }
 
 /// helper function to stream a completion request to stdout
-pub async fn stream_to_stdout<M: StreamingCompletionModel, R>(
+pub async fn stream_to_stdout<M: StreamingCompletionModel>(
     agent: Agent<M>,
-    stream: &mut StreamingCompletionResponse<R>,
+    stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
 ) -> Result<(), std::io::Error> {
     print!("Response: ");
     while let Some(chunk) = stream.next().await {
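End to end, a provider integration now boils down to picking a Clone + Unpin final-response type and wrapping its raw stream in StreamingCompletionResponse::new. A stub under stated assumptions (MyModel is hypothetical and must already implement CompletionModel; a real impl parses provider events instead of yielding constants):

    #[derive(Clone)]
    pub struct MyFinalResponse {
        pub total_tokens: u64,
    }

    impl StreamingCompletionModel for MyModel {
        type StreamingResponse = MyFinalResponse;

        async fn stream(
            &self,
            _request: CompletionRequest,
        ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
            // Raw stream of chunks, ending in exactly one typed FinalResponse.
            let stream: StreamingResult<Self::StreamingResponse> =
                Box::pin(async_stream::stream! {
                    yield Ok(RawStreamingChoice::Message("hello".to_string()));
                    yield Ok(RawStreamingChoice::FinalResponse(MyFinalResponse {
                        total_tokens: 2,
                    }));
                });
            Ok(StreamingCompletionResponse::new(stream))
        }
    }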