feat: update examples

This commit is contained in:
yavens 2025-04-09 15:33:26 -04:00
parent 86b84c82fb
commit 2920eb0a0e
13 changed files with 134 additions and 55 deletions

View File

@ -19,5 +19,13 @@ async fn main() -> Result<(), anyhow::Error> {
stream_to_stdout(agent, &mut stream).await?; stream_to_stdout(agent, &mut stream).await?;
if let Some(response) = stream.response {
println!("Usage: {:?} tokens", response.usage.output_tokens);
};
println!("Message: {:?}", stream.message);
Ok(()) Ok(())
} }

View File

@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> {
println!("Calculate 2 - 5"); println!("Calculate 2 - 5");
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?; let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
stream_to_stdout(calculator_agent, &mut stream).await?; stream_to_stdout(calculator_agent, &mut stream).await?;
if let Some(response) = stream.response {
println!("Usage: {:?} tokens", response.usage.output_tokens);
};
println!("Message: {:?}", stream.message);
Ok(()) Ok(())
} }

View File

@ -19,5 +19,10 @@ async fn main() -> Result<(), anyhow::Error> {
stream_to_stdout(agent, &mut stream).await?; stream_to_stdout(agent, &mut stream).await?;
if let Some(response) = stream.response {
println!("Usage: {:?} tokens", response.usage_metadata.total_token_count);
};
println!("Message: {:?}", stream.message);
Ok(()) Ok(())
} }

View File

@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> {
println!("Calculate 2 - 5"); println!("Calculate 2 - 5");
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?; let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
stream_to_stdout(calculator_agent, &mut stream).await?; stream_to_stdout(calculator_agent, &mut stream).await?;
if let Some(response) = stream.response {
println!("Usage: {:?} tokens", response.usage_metadata.total_token_count);
};
println!("Message: {:?}", stream.message);
Ok(()) Ok(())
} }

View File

@ -17,5 +17,10 @@ async fn main() -> Result<(), anyhow::Error> {
stream_to_stdout(agent, &mut stream).await?; stream_to_stdout(agent, &mut stream).await?;
if let Some(response) = stream.response {
println!("Usage: {:?} tokens", response.eval_count);
};
println!("Message: {:?}", stream.message);
Ok(()) Ok(())
} }

View File

@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> {
println!("Calculate 2 - 5"); println!("Calculate 2 - 5");
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?; let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
stream_to_stdout(calculator_agent, &mut stream).await?; stream_to_stdout(calculator_agent, &mut stream).await?;
if let Some(response) = stream.response {
println!("Usage: {:?} tokens", response.eval_count);
};
println!("Message: {:?}", stream.message);
Ok(()) Ok(())
} }

View File

@ -20,6 +20,8 @@ async fn main() -> Result<(), anyhow::Error> {
if let Some(response) = stream.response { if let Some(response) = stream.response {
println!("Usage: {:?}", response.usage) println!("Usage: {:?}", response.usage)
}; };
println!("Message: {:?}", stream.message);
Ok(()) Ok(())
} }

View File

@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> {
println!("Calculate 2 - 5"); println!("Calculate 2 - 5");
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?; let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
stream_to_stdout(calculator_agent, &mut stream).await?; stream_to_stdout(calculator_agent, &mut stream).await?;
if let Some(response) = stream.response {
println!("Usage: {:?}", response.usage)
};
println!("Message: {:?}", stream.message);
Ok(()) Ok(())
} }

View File

@ -609,7 +609,7 @@ pub mod gemini_api_types {
HarmCategoryCivicIntegrity, HarmCategoryCivicIntegrity,
} }
#[derive(Debug, Deserialize, Clone)] #[derive(Debug, Deserialize, Clone, Default)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct UsageMetadata { pub struct UsageMetadata {
pub prompt_token_count: i32, pub prompt_token_count: i32,

View File

@ -9,18 +9,24 @@ use crate::{
streaming::{self, StreamingCompletionModel}, streaming::{self, StreamingCompletionModel},
}; };
#[derive(Debug, Deserialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PartialUsage {
pub total_token_count: i32,
}
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct StreamGenerateContentResponse { pub struct StreamGenerateContentResponse {
/// Candidate responses from the model. /// Candidate responses from the model.
pub candidates: Vec<ContentCandidate>, pub candidates: Vec<ContentCandidate>,
pub model_version: Option<String>, pub model_version: Option<String>,
pub usage_metadata: UsageMetadata, pub usage_metadata: Option<PartialUsage>,
} }
#[derive(Clone)] #[derive(Clone)]
pub struct StreamingCompletionResponse { pub struct StreamingCompletionResponse {
pub usage_metadata: UsageMetadata, pub usage_metadata: PartialUsage,
} }
impl StreamingCompletionModel for CompletionModel { impl StreamingCompletionModel for CompletionModel {
@ -90,7 +96,9 @@ impl StreamingCompletionModel for CompletionModel {
if choice.finish_reason.is_some() { if choice.finish_reason.is_some() {
yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse { yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
usage_metadata: data.usage_metadata, usage_metadata: PartialUsage {
total_token_count: data.usage_metadata.unwrap().total_token_count,
}
})) }))
} }
} }

View File

@ -486,6 +486,18 @@ impl StreamingCompletionModel for CompletionModel {
continue; continue;
} }
} }
if response.done {
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
total_duration: response.total_duration,
load_duration: response.load_duration,
prompt_eval_count: response.prompt_eval_count,
prompt_eval_duration: response.prompt_eval_duration,
eval_count: response.eval_count,
eval_duration: response.eval_duration,
done_reason: response.done_reason,
}));
}
} }
} }
}); });

View File

@ -46,12 +46,11 @@ struct StreamingChoice {
struct StreamingCompletionChunk { struct StreamingCompletionChunk {
choices: Vec<StreamingChoice>, choices: Vec<StreamingChoice>,
usage: Option<Usage>, usage: Option<Usage>,
finish_reason: Option<String>,
} }
#[derive(Clone)] #[derive(Clone)]
pub struct StreamingCompletionResponse { pub struct StreamingCompletionResponse {
pub usage: Option<Usage>, pub usage: Usage,
} }
impl StreamingCompletionModel for CompletionModel { impl StreamingCompletionModel for CompletionModel {
@ -62,7 +61,10 @@ impl StreamingCompletionModel for CompletionModel {
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
{ {
let mut request = self.create_completion_request(completion_request)?; let mut request = self.create_completion_request(completion_request)?;
request = merge(request, json!({"stream": true})); request = merge(
request,
json!({"stream": true, "stream_options": {"include_usage": true}}),
);
let builder = self.client.post("/chat/completions").json(&request); let builder = self.client.post("/chat/completions").json(&request);
send_compatible_streaming_request(builder).await send_compatible_streaming_request(builder).await
@ -86,6 +88,11 @@ pub async fn send_compatible_streaming_request(
let inner = Box::pin(stream! { let inner = Box::pin(stream! {
let mut stream = response.bytes_stream(); let mut stream = response.bytes_stream();
let mut final_usage = Usage {
prompt_tokens: 0,
total_tokens: 0
};
let mut partial_data = None; let mut partial_data = None;
let mut calls: HashMap<usize, (String, String)> = HashMap::new(); let mut calls: HashMap<usize, (String, String)> = HashMap::new();
@ -110,8 +117,6 @@ pub async fn send_compatible_streaming_request(
for line in text.lines() { for line in text.lines() {
let mut line = line.to_string(); let mut line = line.to_string();
// If there was a remaining part, concat with current line // If there was a remaining part, concat with current line
if partial_data.is_some() { if partial_data.is_some() {
line = format!("{}{}", partial_data.unwrap(), line); line = format!("{}{}", partial_data.unwrap(), line);
@ -137,56 +142,53 @@ pub async fn send_compatible_streaming_request(
continue; continue;
}; };
let choice = data.choices.first().expect("Should have at least one choice");
let delta = &choice.delta; if let Some(choice) = data.choices.first() {
if !delta.tool_calls.is_empty() { let delta = &choice.delta;
for tool_call in &delta.tool_calls {
let function = tool_call.function.clone();
// Start of tool call if !delta.tool_calls.is_empty() {
// name: Some(String) for tool_call in &delta.tool_calls {
// arguments: None let function = tool_call.function.clone();
if function.name.is_some() && function.arguments.is_empty() {
calls.insert(tool_call.index, (function.name.clone().unwrap(), "".to_string()));
}
// Part of tool call
// name: None
// arguments: Some(String)
else if function.name.is_none() && !function.arguments.is_empty() {
let Some((name, arguments)) = calls.get(&tool_call.index) else {
continue;
};
let new_arguments = &tool_call.function.arguments; // Start of tool call
let arguments = format!("{}{}", arguments, new_arguments); // name: Some(String)
// arguments: None
if function.name.is_some() && function.arguments.is_empty() {
calls.insert(tool_call.index, (function.name.clone().unwrap(), "".to_string()));
}
// Part of tool call
// name: None
// arguments: Some(String)
else if function.name.is_none() && !function.arguments.is_empty() {
let Some((name, arguments)) = calls.get(&tool_call.index) else {
continue;
};
calls.insert(tool_call.index, (name.clone(), arguments)); let new_arguments = &tool_call.function.arguments;
} let arguments = format!("{}{}", arguments, new_arguments);
// Entire tool call
else {
let name = function.name.unwrap();
let arguments = function.arguments;
let Ok(arguments) = serde_json::from_str(&arguments) else {
continue;
};
yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments)) calls.insert(tool_call.index, (name.clone(), arguments));
}
// Entire tool call
else {
let name = function.name.unwrap();
let arguments = function.arguments;
let Ok(arguments) = serde_json::from_str(&arguments) else {
continue;
};
yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments))
}
} }
} }
} }
if let Some(content) = &choice.delta.content { if let Some(usage) = data.usage {
yield Ok(streaming::RawStreamingChoice::Message(content.clone())) final_usage = usage.clone();
} }
if data.finish_reason.is_some() {
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
usage: data.usage
}))
}
} }
} }
@ -195,8 +197,12 @@ pub async fn send_compatible_streaming_request(
continue; continue;
}; };
yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments)) yield Ok(RawStreamingChoice::ToolCall(name, "".to_string(), arguments))
} }
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
usage: final_usage.clone()
}))
}); });
Ok(streaming::StreamingCompletionResponse::new(inner)) Ok(streaming::StreamingCompletionResponse::new(inner))

View File

@ -91,21 +91,25 @@ impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
match stream.inner.as_mut().poll_next(cx) { match stream.inner.as_mut().poll_next(cx) {
Poll::Pending => Poll::Pending, Poll::Pending => Poll::Pending,
Poll::Ready(None) => { Poll::Ready(None) => {
let content = vec![AssistantContent::text(stream.text.clone())];
let mut content = vec![];
stream.tool_calls.iter().for_each(|(n, d, a)| { stream.tool_calls.iter().for_each(|(n, d, a)| {
AssistantContent::tool_call(n, d, a.clone()); content.push(AssistantContent::tool_call(n, d, a.clone()));
}); });
if content.len() == 0 || stream.text.len() > 0 {
content.insert(0, AssistantContent::text(stream.text.clone()));
}
stream.message = Message::Assistant { stream.message = Message::Assistant {
content: OneOrMany::many(content) content: OneOrMany::many(content)
.expect("There should be at least one assistant message"), .expect("There should be at least one assistant message"),
}; };
Poll::Ready(None) Poll::Ready(None)
} },
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
Poll::Ready(Some(Ok(choice))) => match choice { Poll::Ready(Some(Ok(choice))) => match choice {
RawStreamingChoice::Message(text) => { RawStreamingChoice::Message(text) => {
@ -120,7 +124,8 @@ impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
} }
RawStreamingChoice::FinalResponse(response) => { RawStreamingChoice::FinalResponse(response) => {
stream.response = Some(response); stream.response = Some(response);
Poll::Pending
stream.poll_next_unpin(cx)
} }
}, },
} }