mirror of https://github.com/0xplaygrounds/rig
feat: update examples
This commit is contained in:
parent
86b84c82fb
commit
2920eb0a0e
|
@ -19,5 +19,13 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
|
|
||||||
stream_to_stdout(agent, &mut stream).await?;
|
stream_to_stdout(agent, &mut stream).await?;
|
||||||
|
|
||||||
|
|
||||||
|
if let Some(response) = stream.response {
|
||||||
|
println!("Usage: {:?} tokens", response.usage.output_tokens);
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("Message: {:?}", stream.message);
|
||||||
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
println!("Calculate 2 - 5");
|
println!("Calculate 2 - 5");
|
||||||
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
|
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
|
||||||
stream_to_stdout(calculator_agent, &mut stream).await?;
|
stream_to_stdout(calculator_agent, &mut stream).await?;
|
||||||
|
|
||||||
|
if let Some(response) = stream.response {
|
||||||
|
println!("Usage: {:?} tokens", response.usage.output_tokens);
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("Message: {:?}", stream.message);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,5 +19,10 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
|
|
||||||
stream_to_stdout(agent, &mut stream).await?;
|
stream_to_stdout(agent, &mut stream).await?;
|
||||||
|
|
||||||
|
if let Some(response) = stream.response {
|
||||||
|
println!("Usage: {:?} tokens", response.usage_metadata.total_token_count);
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("Message: {:?}", stream.message);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
println!("Calculate 2 - 5");
|
println!("Calculate 2 - 5");
|
||||||
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
|
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
|
||||||
stream_to_stdout(calculator_agent, &mut stream).await?;
|
stream_to_stdout(calculator_agent, &mut stream).await?;
|
||||||
|
|
||||||
|
if let Some(response) = stream.response {
|
||||||
|
println!("Usage: {:?} tokens", response.usage_metadata.total_token_count);
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("Message: {:?}", stream.message);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,5 +17,10 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
|
|
||||||
stream_to_stdout(agent, &mut stream).await?;
|
stream_to_stdout(agent, &mut stream).await?;
|
||||||
|
|
||||||
|
if let Some(response) = stream.response {
|
||||||
|
println!("Usage: {:?} tokens", response.eval_count);
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("Message: {:?}", stream.message);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
println!("Calculate 2 - 5");
|
println!("Calculate 2 - 5");
|
||||||
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
|
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
|
||||||
stream_to_stdout(calculator_agent, &mut stream).await?;
|
stream_to_stdout(calculator_agent, &mut stream).await?;
|
||||||
|
|
||||||
|
if let Some(response) = stream.response {
|
||||||
|
println!("Usage: {:?} tokens", response.eval_count);
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("Message: {:?}", stream.message);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,8 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
if let Some(response) = stream.response {
|
if let Some(response) = stream.response {
|
||||||
println!("Usage: {:?}", response.usage)
|
println!("Usage: {:?}", response.usage)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
println!("Message: {:?}", stream.message);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -107,5 +107,12 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
println!("Calculate 2 - 5");
|
println!("Calculate 2 - 5");
|
||||||
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
|
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
|
||||||
stream_to_stdout(calculator_agent, &mut stream).await?;
|
stream_to_stdout(calculator_agent, &mut stream).await?;
|
||||||
|
|
||||||
|
if let Some(response) = stream.response {
|
||||||
|
println!("Usage: {:?}", response.usage)
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("Message: {:?}", stream.message);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -609,7 +609,7 @@ pub mod gemini_api_types {
|
||||||
HarmCategoryCivicIntegrity,
|
HarmCategoryCivicIntegrity,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Clone)]
|
#[derive(Debug, Deserialize, Clone, Default)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct UsageMetadata {
|
pub struct UsageMetadata {
|
||||||
pub prompt_token_count: i32,
|
pub prompt_token_count: i32,
|
||||||
|
|
|
@ -9,18 +9,24 @@ use crate::{
|
||||||
streaming::{self, StreamingCompletionModel},
|
streaming::{self, StreamingCompletionModel},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Default, Clone)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct PartialUsage {
|
||||||
|
pub total_token_count: i32,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct StreamGenerateContentResponse {
|
pub struct StreamGenerateContentResponse {
|
||||||
/// Candidate responses from the model.
|
/// Candidate responses from the model.
|
||||||
pub candidates: Vec<ContentCandidate>,
|
pub candidates: Vec<ContentCandidate>,
|
||||||
pub model_version: Option<String>,
|
pub model_version: Option<String>,
|
||||||
pub usage_metadata: UsageMetadata,
|
pub usage_metadata: Option<PartialUsage>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct StreamingCompletionResponse {
|
pub struct StreamingCompletionResponse {
|
||||||
pub usage_metadata: UsageMetadata,
|
pub usage_metadata: PartialUsage,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
|
@ -90,7 +96,9 @@ impl StreamingCompletionModel for CompletionModel {
|
||||||
|
|
||||||
if choice.finish_reason.is_some() {
|
if choice.finish_reason.is_some() {
|
||||||
yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
|
yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
|
||||||
usage_metadata: data.usage_metadata,
|
usage_metadata: PartialUsage {
|
||||||
|
total_token_count: data.usage_metadata.unwrap().total_token_count,
|
||||||
|
}
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -486,6 +486,18 @@ impl StreamingCompletionModel for CompletionModel {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if response.done {
|
||||||
|
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
|
||||||
|
total_duration: response.total_duration,
|
||||||
|
load_duration: response.load_duration,
|
||||||
|
prompt_eval_count: response.prompt_eval_count,
|
||||||
|
prompt_eval_duration: response.prompt_eval_duration,
|
||||||
|
eval_count: response.eval_count,
|
||||||
|
eval_duration: response.eval_duration,
|
||||||
|
done_reason: response.done_reason,
|
||||||
|
}));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
|
@ -46,12 +46,11 @@ struct StreamingChoice {
|
||||||
struct StreamingCompletionChunk {
|
struct StreamingCompletionChunk {
|
||||||
choices: Vec<StreamingChoice>,
|
choices: Vec<StreamingChoice>,
|
||||||
usage: Option<Usage>,
|
usage: Option<Usage>,
|
||||||
finish_reason: Option<String>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct StreamingCompletionResponse {
|
pub struct StreamingCompletionResponse {
|
||||||
pub usage: Option<Usage>,
|
pub usage: Usage,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingCompletionModel for CompletionModel {
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
|
@ -62,7 +61,10 @@ impl StreamingCompletionModel for CompletionModel {
|
||||||
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
|
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
|
||||||
{
|
{
|
||||||
let mut request = self.create_completion_request(completion_request)?;
|
let mut request = self.create_completion_request(completion_request)?;
|
||||||
request = merge(request, json!({"stream": true}));
|
request = merge(
|
||||||
|
request,
|
||||||
|
json!({"stream": true, "stream_options": {"include_usage": true}}),
|
||||||
|
);
|
||||||
|
|
||||||
let builder = self.client.post("/chat/completions").json(&request);
|
let builder = self.client.post("/chat/completions").json(&request);
|
||||||
send_compatible_streaming_request(builder).await
|
send_compatible_streaming_request(builder).await
|
||||||
|
@ -86,6 +88,11 @@ pub async fn send_compatible_streaming_request(
|
||||||
let inner = Box::pin(stream! {
|
let inner = Box::pin(stream! {
|
||||||
let mut stream = response.bytes_stream();
|
let mut stream = response.bytes_stream();
|
||||||
|
|
||||||
|
let mut final_usage = Usage {
|
||||||
|
prompt_tokens: 0,
|
||||||
|
total_tokens: 0
|
||||||
|
};
|
||||||
|
|
||||||
let mut partial_data = None;
|
let mut partial_data = None;
|
||||||
let mut calls: HashMap<usize, (String, String)> = HashMap::new();
|
let mut calls: HashMap<usize, (String, String)> = HashMap::new();
|
||||||
|
|
||||||
|
@ -110,8 +117,6 @@ pub async fn send_compatible_streaming_request(
|
||||||
for line in text.lines() {
|
for line in text.lines() {
|
||||||
let mut line = line.to_string();
|
let mut line = line.to_string();
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// If there was a remaining part, concat with current line
|
// If there was a remaining part, concat with current line
|
||||||
if partial_data.is_some() {
|
if partial_data.is_some() {
|
||||||
line = format!("{}{}", partial_data.unwrap(), line);
|
line = format!("{}{}", partial_data.unwrap(), line);
|
||||||
|
@ -137,56 +142,53 @@ pub async fn send_compatible_streaming_request(
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
|
|
||||||
let choice = data.choices.first().expect("Should have at least one choice");
|
|
||||||
|
|
||||||
let delta = &choice.delta;
|
if let Some(choice) = data.choices.first() {
|
||||||
|
|
||||||
if !delta.tool_calls.is_empty() {
|
let delta = &choice.delta;
|
||||||
for tool_call in &delta.tool_calls {
|
|
||||||
let function = tool_call.function.clone();
|
|
||||||
|
|
||||||
// Start of tool call
|
if !delta.tool_calls.is_empty() {
|
||||||
// name: Some(String)
|
for tool_call in &delta.tool_calls {
|
||||||
// arguments: None
|
let function = tool_call.function.clone();
|
||||||
if function.name.is_some() && function.arguments.is_empty() {
|
|
||||||
calls.insert(tool_call.index, (function.name.clone().unwrap(), "".to_string()));
|
|
||||||
}
|
|
||||||
// Part of tool call
|
|
||||||
// name: None
|
|
||||||
// arguments: Some(String)
|
|
||||||
else if function.name.is_none() && !function.arguments.is_empty() {
|
|
||||||
let Some((name, arguments)) = calls.get(&tool_call.index) else {
|
|
||||||
continue;
|
|
||||||
};
|
|
||||||
|
|
||||||
let new_arguments = &tool_call.function.arguments;
|
// Start of tool call
|
||||||
let arguments = format!("{}{}", arguments, new_arguments);
|
// name: Some(String)
|
||||||
|
// arguments: None
|
||||||
|
if function.name.is_some() && function.arguments.is_empty() {
|
||||||
|
calls.insert(tool_call.index, (function.name.clone().unwrap(), "".to_string()));
|
||||||
|
}
|
||||||
|
// Part of tool call
|
||||||
|
// name: None
|
||||||
|
// arguments: Some(String)
|
||||||
|
else if function.name.is_none() && !function.arguments.is_empty() {
|
||||||
|
let Some((name, arguments)) = calls.get(&tool_call.index) else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
calls.insert(tool_call.index, (name.clone(), arguments));
|
let new_arguments = &tool_call.function.arguments;
|
||||||
}
|
let arguments = format!("{}{}", arguments, new_arguments);
|
||||||
// Entire tool call
|
|
||||||
else {
|
|
||||||
let name = function.name.unwrap();
|
|
||||||
let arguments = function.arguments;
|
|
||||||
let Ok(arguments) = serde_json::from_str(&arguments) else {
|
|
||||||
continue;
|
|
||||||
};
|
|
||||||
|
|
||||||
yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments))
|
calls.insert(tool_call.index, (name.clone(), arguments));
|
||||||
|
}
|
||||||
|
// Entire tool call
|
||||||
|
else {
|
||||||
|
let name = function.name.unwrap();
|
||||||
|
let arguments = function.arguments;
|
||||||
|
let Ok(arguments) = serde_json::from_str(&arguments) else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(content) = &choice.delta.content {
|
if let Some(usage) = data.usage {
|
||||||
yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
|
final_usage = usage.clone();
|
||||||
}
|
}
|
||||||
|
|
||||||
if data.finish_reason.is_some() {
|
|
||||||
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
|
|
||||||
usage: data.usage
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -195,8 +197,12 @@ pub async fn send_compatible_streaming_request(
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
|
|
||||||
yield Ok(streaming::RawStreamingChoice::ToolCall(name, "".to_string(), arguments))
|
yield Ok(RawStreamingChoice::ToolCall(name, "".to_string(), arguments))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
|
||||||
|
usage: final_usage.clone()
|
||||||
|
}))
|
||||||
});
|
});
|
||||||
|
|
||||||
Ok(streaming::StreamingCompletionResponse::new(inner))
|
Ok(streaming::StreamingCompletionResponse::new(inner))
|
||||||
|
|
|
@ -91,21 +91,25 @@ impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
|
||||||
|
|
||||||
match stream.inner.as_mut().poll_next(cx) {
|
match stream.inner.as_mut().poll_next(cx) {
|
||||||
Poll::Pending => Poll::Pending,
|
Poll::Pending => Poll::Pending,
|
||||||
|
|
||||||
Poll::Ready(None) => {
|
Poll::Ready(None) => {
|
||||||
let content = vec![AssistantContent::text(stream.text.clone())];
|
|
||||||
|
let mut content = vec![];
|
||||||
|
|
||||||
stream.tool_calls.iter().for_each(|(n, d, a)| {
|
stream.tool_calls.iter().for_each(|(n, d, a)| {
|
||||||
AssistantContent::tool_call(n, d, a.clone());
|
content.push(AssistantContent::tool_call(n, d, a.clone()));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if content.len() == 0 || stream.text.len() > 0 {
|
||||||
|
content.insert(0, AssistantContent::text(stream.text.clone()));
|
||||||
|
}
|
||||||
|
|
||||||
stream.message = Message::Assistant {
|
stream.message = Message::Assistant {
|
||||||
content: OneOrMany::many(content)
|
content: OneOrMany::many(content)
|
||||||
.expect("There should be at least one assistant message"),
|
.expect("There should be at least one assistant message"),
|
||||||
};
|
};
|
||||||
|
|
||||||
Poll::Ready(None)
|
Poll::Ready(None)
|
||||||
}
|
},
|
||||||
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
|
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
|
||||||
Poll::Ready(Some(Ok(choice))) => match choice {
|
Poll::Ready(Some(Ok(choice))) => match choice {
|
||||||
RawStreamingChoice::Message(text) => {
|
RawStreamingChoice::Message(text) => {
|
||||||
|
@ -120,7 +124,8 @@ impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
|
||||||
}
|
}
|
||||||
RawStreamingChoice::FinalResponse(response) => {
|
RawStreamingChoice::FinalResponse(response) => {
|
||||||
stream.response = Some(response);
|
stream.response = Some(response);
|
||||||
Poll::Pending
|
|
||||||
|
stream.poll_next_unpin(cx)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue