mirror of https://github.com/0xplaygrounds/rig
feat: cohere streaming + unify StreamingChoice w/ message
This commit is contained in:
parent
cad584455a
commit
0abd7b8c76
|
@ -0,0 +1,27 @@
|
||||||
|
use rig::providers::cohere;
|
||||||
|
use rig::streaming::{stream_to_stdout, StreamingPrompt};
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), anyhow::Error> {
|
||||||
|
// Create streaming agent with a single context prompt
|
||||||
|
let agent = cohere::Client::from_env()
|
||||||
|
.agent(cohere::COMMAND)
|
||||||
|
.preamble("Be precise and concise.")
|
||||||
|
.temperature(0.5)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
// Stream the response and print chunks as they arrive
|
||||||
|
let mut stream = agent
|
||||||
|
.stream_prompt("When and where and what type is the next solar eclipse?")
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
stream_to_stdout(agent, &mut stream).await?;
|
||||||
|
|
||||||
|
if let Some(response) = stream.response {
|
||||||
|
println!("Usage: {:?} tokens", response.usage);
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("Message: {:?}", stream.message);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
|
@ -0,0 +1,118 @@
|
||||||
|
use anyhow::Result;
|
||||||
|
use rig::streaming::stream_to_stdout;
|
||||||
|
use rig::{completion::ToolDefinition, providers, streaming::StreamingPrompt, tool::Tool};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct OperationArgs {
|
||||||
|
x: i32,
|
||||||
|
y: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
#[error("Math error")]
|
||||||
|
struct MathError;
|
||||||
|
|
||||||
|
#[derive(Deserialize, Serialize)]
|
||||||
|
struct Adder;
|
||||||
|
impl Tool for Adder {
|
||||||
|
const NAME: &'static str = "add";
|
||||||
|
|
||||||
|
type Error = MathError;
|
||||||
|
type Args = OperationArgs;
|
||||||
|
type Output = i32;
|
||||||
|
|
||||||
|
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||||
|
ToolDefinition {
|
||||||
|
name: "add".to_string(),
|
||||||
|
description: "Add x and y together".to_string(),
|
||||||
|
parameters: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"x": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The first number to add"
|
||||||
|
},
|
||||||
|
"y": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The second number to add"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["x", "y"]
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||||
|
let result = args.x + args.y;
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Serialize)]
|
||||||
|
struct Subtract;
|
||||||
|
impl Tool for Subtract {
|
||||||
|
const NAME: &'static str = "subtract";
|
||||||
|
|
||||||
|
type Error = MathError;
|
||||||
|
type Args = OperationArgs;
|
||||||
|
type Output = i32;
|
||||||
|
|
||||||
|
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||||
|
serde_json::from_value(json!({
|
||||||
|
"name": "subtract",
|
||||||
|
"description": "Subtract y from x (i.e.: x - y)",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"x": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The number to subtract from"
|
||||||
|
},
|
||||||
|
"y": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The number to subtract"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["x", "y"]
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.expect("Tool Definition")
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||||
|
let result = args.x - args.y;
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), anyhow::Error> {
|
||||||
|
tracing_subscriber::fmt().init();
|
||||||
|
// Create agent with a single context prompt and two tools
|
||||||
|
let calculator_agent = providers::cohere::Client::from_env()
|
||||||
|
.agent(providers::cohere::COMMAND_R)
|
||||||
|
.preamble(
|
||||||
|
"You are a calculator here to help the user perform arithmetic
|
||||||
|
operations. Use the tools provided to answer the user's question.
|
||||||
|
make your answer long, so we can test the streaming functionality,
|
||||||
|
like 20 words",
|
||||||
|
)
|
||||||
|
.max_tokens(1024)
|
||||||
|
.tool(Adder)
|
||||||
|
.tool(Subtract)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
println!("Calculate 2 - 5");
|
||||||
|
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
|
||||||
|
stream_to_stdout(calculator_agent, &mut stream).await?;
|
||||||
|
|
||||||
|
if let Some(response) = stream.response {
|
||||||
|
println!("Usage: {:?} tokens", response.usage);
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("Message: {:?}", stream.message);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
|
@ -6,8 +6,9 @@ use crate::{
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::client::Client;
|
use super::client::Client;
|
||||||
|
use crate::completion::CompletionRequest;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::json;
|
use serde_json::{json, Value};
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct CompletionResponse {
|
pub struct CompletionResponse {
|
||||||
|
@ -419,7 +420,7 @@ impl TryFrom<Message> for message::Message {
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct CompletionModel {
|
pub struct CompletionModel {
|
||||||
client: Client,
|
pub(crate) client: Client,
|
||||||
pub model: String,
|
pub model: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -430,16 +431,11 @@ impl CompletionModel {
|
||||||
model: model.to_string(),
|
model: model.to_string(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
impl completion::CompletionModel for CompletionModel {
|
pub(crate) fn create_completion_request(
|
||||||
type Response = CompletionResponse;
|
|
||||||
|
|
||||||
#[cfg_attr(feature = "worker", worker::send)]
|
|
||||||
async fn completion(
|
|
||||||
&self,
|
&self,
|
||||||
completion_request: completion::CompletionRequest,
|
completion_request: CompletionRequest,
|
||||||
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
|
) -> Result<Value, CompletionError> {
|
||||||
let prompt = completion_request.prompt_with_context();
|
let prompt = completion_request.prompt_with_context();
|
||||||
|
|
||||||
let mut messages: Vec<message::Message> =
|
let mut messages: Vec<message::Message> =
|
||||||
|
@ -468,23 +464,29 @@ impl completion::CompletionModel for CompletionModel {
|
||||||
"tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
|
"tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if let Some(ref params) = completion_request.additional_params {
|
||||||
|
Ok(json_utils::merge(request.clone(), params.clone()))
|
||||||
|
} else {
|
||||||
|
Ok(request)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl completion::CompletionModel for CompletionModel {
|
||||||
|
type Response = CompletionResponse;
|
||||||
|
|
||||||
|
#[cfg_attr(feature = "worker", worker::send)]
|
||||||
|
async fn completion(
|
||||||
|
&self,
|
||||||
|
completion_request: completion::CompletionRequest,
|
||||||
|
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
|
||||||
|
let request = self.create_completion_request(completion_request)?;
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
"Cohere request: {}",
|
"Cohere request: {}",
|
||||||
serde_json::to_string_pretty(&request)?
|
serde_json::to_string_pretty(&request)?
|
||||||
);
|
);
|
||||||
|
|
||||||
let response = self
|
let response = self.client.post("/v2/chat").json(&request).send().await?;
|
||||||
.client
|
|
||||||
.post("/v2/chat")
|
|
||||||
.json(
|
|
||||||
&if let Some(ref params) = completion_request.additional_params {
|
|
||||||
json_utils::merge(request.clone(), params.clone())
|
|
||||||
} else {
|
|
||||||
request.clone()
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.send()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
if response.status().is_success() {
|
if response.status().is_success() {
|
||||||
let text_response = response.text().await?;
|
let text_response = response.text().await?;
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
pub mod client;
|
pub mod client;
|
||||||
pub mod completion;
|
pub mod completion;
|
||||||
pub mod embeddings;
|
pub mod embeddings;
|
||||||
|
pub mod streaming;
|
||||||
|
|
||||||
pub use client::Client;
|
pub use client::Client;
|
||||||
pub use client::{ApiErrorResponse, ApiResponse};
|
pub use client::{ApiErrorResponse, ApiResponse};
|
||||||
|
@ -23,7 +24,7 @@ pub use embeddings::EmbeddingModel;
|
||||||
// ================================================================
|
// ================================================================
|
||||||
|
|
||||||
/// `command-r-plus` completion model
|
/// `command-r-plus` completion model
|
||||||
pub const COMMAND_R_PLUS: &str = "comman-r-plus";
|
pub const COMMAND_R_PLUS: &str = "command-r-plus";
|
||||||
/// `command-r` completion model
|
/// `command-r` completion model
|
||||||
pub const COMMAND_R: &str = "command-r";
|
pub const COMMAND_R: &str = "command-r";
|
||||||
/// `command` completion model
|
/// `command` completion model
|
||||||
|
|
|
@ -0,0 +1,195 @@
|
||||||
|
use crate::completion::{CompletionError, CompletionRequest};
|
||||||
|
use crate::message::{ToolCall, ToolFunction};
|
||||||
|
use crate::providers::cohere::completion::{AssistantContent, BilledUnits, Message, Usage};
|
||||||
|
use crate::providers::cohere::CompletionModel;
|
||||||
|
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
|
||||||
|
use crate::{json_utils, streaming};
|
||||||
|
use async_stream::stream;
|
||||||
|
use futures::StreamExt;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use serde_json::{json, Value};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::future::Future;
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
#[serde(rename_all = "kebab-case", tag = "type")]
|
||||||
|
pub enum StreamingEvent {
|
||||||
|
MessageStart,
|
||||||
|
ContentStart,
|
||||||
|
ContentDelta { delta: Option<Delta> },
|
||||||
|
ContentEnd,
|
||||||
|
ToolPlan { delta: Option<Delta> },
|
||||||
|
ToolCallStart { delta: Option<Delta> },
|
||||||
|
ToolCallDelta { delta: Option<Delta> },
|
||||||
|
ToolCallEnd { delta: Option<Delta> },
|
||||||
|
MessageEnd { delta: Option<MessageEndDelta> },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct MessageContentDelta {
|
||||||
|
r#type: Option<String>,
|
||||||
|
text: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct MessageToolFunctionDelta {
|
||||||
|
name: Option<String>,
|
||||||
|
arguments: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct MessageToolCallDelta {
|
||||||
|
id: Option<String>,
|
||||||
|
r#type: Option<String>,
|
||||||
|
function: Option<MessageToolFunctionDelta>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct MessageDelta {
|
||||||
|
content: Option<MessageContentDelta>,
|
||||||
|
tool_plan: Option<String>,
|
||||||
|
tool_calls: Option<MessageToolCallDelta>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct Delta {
|
||||||
|
message: Option<MessageDelta>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct MessageEndDelta {
|
||||||
|
finish_reason: Option<String>,
|
||||||
|
usage: Option<Usage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct StreamingCompletionResponse {
|
||||||
|
pub usage: Option<Usage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StreamingCompletionModel for CompletionModel {
|
||||||
|
type StreamingResponse = StreamingCompletionResponse;
|
||||||
|
|
||||||
|
async fn stream(
|
||||||
|
&self,
|
||||||
|
request: CompletionRequest,
|
||||||
|
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
|
||||||
|
{
|
||||||
|
let request = self.create_completion_request(request)?;
|
||||||
|
let request = json_utils::merge(request, json!({"stream": true}));
|
||||||
|
|
||||||
|
tracing::debug!(
|
||||||
|
"Cohere request: {}",
|
||||||
|
serde_json::to_string_pretty(&request)?
|
||||||
|
);
|
||||||
|
|
||||||
|
let response = self.client.post("/v2/chat").json(&request).send().await?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
return Err(CompletionError::ProviderError(format!(
|
||||||
|
"{}: {}",
|
||||||
|
response.status(),
|
||||||
|
response.text().await?
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let stream = Box::pin(stream! {
|
||||||
|
let mut stream = response.bytes_stream();
|
||||||
|
let mut current_tool_call: Option<(String, String, String)> = None;
|
||||||
|
|
||||||
|
while let Some(chunk_result) = stream.next().await {
|
||||||
|
let chunk = match chunk_result {
|
||||||
|
Ok(c) => c,
|
||||||
|
Err(e) => {
|
||||||
|
yield Err(CompletionError::from(e));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let text = match String::from_utf8(chunk.to_vec()) {
|
||||||
|
Ok(t) => t,
|
||||||
|
Err(e) => {
|
||||||
|
yield Err(CompletionError::ResponseError(e.to_string()));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
for line in text.lines() {
|
||||||
|
|
||||||
|
let Some(line) = line.strip_prefix("data: ") else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
let event = {
|
||||||
|
let result = serde_json::from_str::<StreamingEvent>(&line);
|
||||||
|
|
||||||
|
let Ok(event) = result else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
event
|
||||||
|
};
|
||||||
|
|
||||||
|
match event {
|
||||||
|
StreamingEvent::ContentDelta { delta: Some(delta) } => {
|
||||||
|
let Some(message) = &delta.message else { continue; };
|
||||||
|
let Some(content) = &message.content else { continue; };
|
||||||
|
let Some(text) = &content.text else { continue; };
|
||||||
|
|
||||||
|
yield Ok(RawStreamingChoice::Message(text.clone()));
|
||||||
|
},
|
||||||
|
StreamingEvent::MessageEnd {delta: Some(delta)} => {
|
||||||
|
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
|
||||||
|
usage: delta.usage.clone()
|
||||||
|
}));
|
||||||
|
},
|
||||||
|
StreamingEvent::ToolCallStart { delta: Some(delta)} => {
|
||||||
|
// Skip the delta if there's any missing information,
|
||||||
|
// though this *should* all be present
|
||||||
|
let Some(message) = &delta.message else { continue; };
|
||||||
|
let Some(tool_calls) = &message.tool_calls else { continue; };
|
||||||
|
let Some(id) = tool_calls.id.clone() else { continue; };
|
||||||
|
let Some(function) = &tool_calls.function else { continue; };
|
||||||
|
let Some(name) = function.name.clone() else { continue; };
|
||||||
|
let Some(arguments) = function.arguments.clone() else { continue; };
|
||||||
|
|
||||||
|
current_tool_call = Some((id, name, arguments));
|
||||||
|
},
|
||||||
|
StreamingEvent::ToolCallDelta { delta: Some(delta)} => {
|
||||||
|
// Skip the delta if there's any missing information,
|
||||||
|
// though this *should* all be present
|
||||||
|
let Some(message) = &delta.message else { continue; };
|
||||||
|
let Some(tool_calls) = &message.tool_calls else { continue; };
|
||||||
|
let Some(function) = &tool_calls.function else { continue; };
|
||||||
|
let Some(arguments) = function.arguments.clone() else { continue; };
|
||||||
|
|
||||||
|
if let Some(tc) = current_tool_call.clone() {
|
||||||
|
current_tool_call = Some((
|
||||||
|
tc.0,
|
||||||
|
tc.1,
|
||||||
|
format!("{}{}", tc.2, arguments)
|
||||||
|
));
|
||||||
|
};
|
||||||
|
},
|
||||||
|
StreamingEvent::ToolCallEnd { .. } => {
|
||||||
|
let Some(tc) = current_tool_call.clone() else { continue; };
|
||||||
|
|
||||||
|
let Ok(args) = serde_json::from_str(&tc.2) else { continue; };
|
||||||
|
|
||||||
|
yield Ok(RawStreamingChoice::ToolCall(
|
||||||
|
tc.0,
|
||||||
|
tc.1,
|
||||||
|
args
|
||||||
|
));
|
||||||
|
|
||||||
|
current_tool_call = None;
|
||||||
|
},
|
||||||
|
_ => {}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(streaming::StreamingCompletionResponse::new(stream))
|
||||||
|
}
|
||||||
|
}
|
|
@ -35,27 +35,6 @@ pub enum RawStreamingChoice<R: Clone> {
|
||||||
FinalResponse(R),
|
FinalResponse(R),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Enum representing a streaming chunk from the model
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub enum StreamingChoice {
|
|
||||||
/// A text chunk from a message response
|
|
||||||
Message(String),
|
|
||||||
|
|
||||||
/// A tool call response chunk
|
|
||||||
ToolCall(String, String, serde_json::Value),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for StreamingChoice {
|
|
||||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
StreamingChoice::Message(text) => write!(f, "{}", text),
|
|
||||||
StreamingChoice::ToolCall(name, id, params) => {
|
|
||||||
write!(f, "Tool call: {} {} {:?}", name, id, params)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(not(target_arch = "wasm32"))]
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
pub type StreamingResult<R> =
|
pub type StreamingResult<R> =
|
||||||
Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;
|
Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;
|
||||||
|
@ -85,7 +64,7 @@ impl<R: Clone + Unpin> StreamingCompletionResponse<R> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
|
impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
|
||||||
type Item = Result<StreamingChoice, CompletionError>;
|
type Item = Result<AssistantContent, CompletionError>;
|
||||||
|
|
||||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
let stream = self.get_mut();
|
let stream = self.get_mut();
|
||||||
|
@ -114,13 +93,13 @@ impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
|
||||||
Poll::Ready(Some(Ok(choice))) => match choice {
|
Poll::Ready(Some(Ok(choice))) => match choice {
|
||||||
RawStreamingChoice::Message(text) => {
|
RawStreamingChoice::Message(text) => {
|
||||||
stream.text = format!("{}{}", stream.text, text.clone());
|
stream.text = format!("{}{}", stream.text, text.clone());
|
||||||
Poll::Ready(Some(Ok(StreamingChoice::Message(text))))
|
Poll::Ready(Some(Ok(AssistantContent::text(text))))
|
||||||
}
|
}
|
||||||
RawStreamingChoice::ToolCall(name, description, args) => {
|
RawStreamingChoice::ToolCall(id, name, args) => {
|
||||||
stream
|
stream
|
||||||
.tool_calls
|
.tool_calls
|
||||||
.push((name.clone(), description.clone(), args.clone()));
|
.push((id.clone(), name.clone(), args.clone()));
|
||||||
Poll::Ready(Some(Ok(StreamingChoice::ToolCall(name, description, args))))
|
Poll::Ready(Some(Ok(AssistantContent::tool_call(id, name, args))))
|
||||||
}
|
}
|
||||||
RawStreamingChoice::FinalResponse(response) => {
|
RawStreamingChoice::FinalResponse(response) => {
|
||||||
stream.response = Some(response);
|
stream.response = Some(response);
|
||||||
|
@ -181,14 +160,17 @@ pub async fn stream_to_stdout<M: StreamingCompletionModel>(
|
||||||
print!("Response: ");
|
print!("Response: ");
|
||||||
while let Some(chunk) = stream.next().await {
|
while let Some(chunk) = stream.next().await {
|
||||||
match chunk {
|
match chunk {
|
||||||
Ok(StreamingChoice::Message(text)) => {
|
Ok(AssistantContent::Text(text)) => {
|
||||||
print!("{}", text);
|
print!("{}", text.text);
|
||||||
std::io::Write::flush(&mut std::io::stdout())?;
|
std::io::Write::flush(&mut std::io::stdout())?;
|
||||||
}
|
}
|
||||||
Ok(StreamingChoice::ToolCall(name, _, params)) => {
|
Ok(AssistantContent::ToolCall(tool_call)) => {
|
||||||
let res = agent
|
let res = agent
|
||||||
.tools
|
.tools
|
||||||
.call(&name, params.to_string())
|
.call(
|
||||||
|
&tool_call.function.name,
|
||||||
|
tool_call.function.arguments.to_string(),
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| std::io::Error::other(e.to_string()))?;
|
.map_err(|e| std::io::Error::other(e.to_string()))?;
|
||||||
println!("\nResult: {}", res);
|
println!("\nResult: {}", res);
|
||||||
|
|
Loading…
Reference in New Issue