mirror of https://github.com/0xplaygrounds/rig
feat: cohere streaming + unify StreamingChoice w/ message
This commit is contained in: parent cad584455a, commit 0abd7b8c76
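Before the diff itself, here is a rough consumer sketch of the unified streaming API this commit introduces: a stream now yields the shared AssistantContent message chunks instead of the removed StreamingChoice enum. The variant names and field accesses (text.text, tool_call.function.name, tool_call.function.arguments) mirror the changes below, but the import paths, the prompt text, and the printing logic are illustrative assumptions rather than part of the commit.

use futures::StreamExt;
use rig::message::AssistantContent; // assumed re-export path for the shared message types
use rig::providers::cohere;
use rig::streaming::StreamingPrompt;

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    let agent = cohere::Client::from_env()
        .agent(cohere::COMMAND)
        .build();

    // stream_prompt returns a StreamingCompletionResponse, which after this
    // change implements Stream<Item = Result<AssistantContent, CompletionError>>
    let mut stream = agent.stream_prompt("Say hello.").await?;

    while let Some(chunk) = stream.next().await {
        match chunk? {
            // Text deltas arrive as AssistantContent::Text
            AssistantContent::Text(text) => {
                print!("{}", text.text);
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            // Tool calls are yielded fully assembled once the provider emits ToolCallEnd
            AssistantContent::ToolCall(tool_call) => {
                println!(
                    "\nTool call: {} {}",
                    tool_call.function.name, tool_call.function.arguments
                );
            }
        }
    }

    Ok(())
}

The two example hunks below drive the same loop through the library's stream_to_stdout helper, whose updated match over AssistantContent appears in the final hunk.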
@@ -0,0 +1,27 @@
use rig::providers::cohere;
use rig::streaming::{stream_to_stdout, StreamingPrompt};

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    // Create streaming agent with a single context prompt
    let agent = cohere::Client::from_env()
        .agent(cohere::COMMAND)
        .preamble("Be precise and concise.")
        .temperature(0.5)
        .build();

    // Stream the response and print chunks as they arrive
    let mut stream = agent
        .stream_prompt("When and where and what type is the next solar eclipse?")
        .await?;

    stream_to_stdout(agent, &mut stream).await?;

    if let Some(response) = stream.response {
        println!("Usage: {:?} tokens", response.usage);
    };

    println!("Message: {:?}", stream.message);

    Ok(())
}
@@ -0,0 +1,118 @@
use anyhow::Result;
use rig::streaming::stream_to_stdout;
use rig::{completion::ToolDefinition, providers, streaming::StreamingPrompt, tool::Tool};
use serde::{Deserialize, Serialize};
use serde_json::json;

#[derive(Deserialize)]
struct OperationArgs {
    x: i32,
    y: i32,
}

#[derive(Debug, thiserror::Error)]
#[error("Math error")]
struct MathError;

#[derive(Deserialize, Serialize)]
struct Adder;
impl Tool for Adder {
    const NAME: &'static str = "add";

    type Error = MathError;
    type Args = OperationArgs;
    type Output = i32;

    async fn definition(&self, _prompt: String) -> ToolDefinition {
        ToolDefinition {
            name: "add".to_string(),
            description: "Add x and y together".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The first number to add"
                    },
                    "y": {
                        "type": "number",
                        "description": "The second number to add"
                    }
                },
                "required": ["x", "y"]
            }),
        }
    }

    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        let result = args.x + args.y;
        Ok(result)
    }
}

#[derive(Deserialize, Serialize)]
struct Subtract;
impl Tool for Subtract {
    const NAME: &'static str = "subtract";

    type Error = MathError;
    type Args = OperationArgs;
    type Output = i32;

    async fn definition(&self, _prompt: String) -> ToolDefinition {
        serde_json::from_value(json!({
            "name": "subtract",
            "description": "Subtract y from x (i.e.: x - y)",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The number to subtract from"
                    },
                    "y": {
                        "type": "number",
                        "description": "The number to subtract"
                    }
                },
                "required": ["x", "y"]
            }
        }))
        .expect("Tool Definition")
    }

    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        let result = args.x - args.y;
        Ok(result)
    }
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    tracing_subscriber::fmt().init();
    // Create agent with a single context prompt and two tools
    let calculator_agent = providers::cohere::Client::from_env()
        .agent(providers::cohere::COMMAND_R)
        .preamble(
            "You are a calculator here to help the user perform arithmetic
            operations. Use the tools provided to answer the user's question.
            make your answer long, so we can test the streaming functionality,
            like 20 words",
        )
        .max_tokens(1024)
        .tool(Adder)
        .tool(Subtract)
        .build();

    println!("Calculate 2 - 5");
    let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
    stream_to_stdout(calculator_agent, &mut stream).await?;

    if let Some(response) = stream.response {
        println!("Usage: {:?} tokens", response.usage);
    };

    println!("Message: {:?}", stream.message);

    Ok(())
}
@@ -6,8 +6,9 @@ use crate::{
 };
 
 use super::client::Client;
+use crate::completion::CompletionRequest;
 use serde::{Deserialize, Serialize};
-use serde_json::json;
+use serde_json::{json, Value};
 
 #[derive(Debug, Deserialize)]
 pub struct CompletionResponse {
@@ -419,7 +420,7 @@ impl TryFrom<Message> for message::Message {
 
 #[derive(Clone)]
 pub struct CompletionModel {
-    client: Client,
+    pub(crate) client: Client,
     pub model: String,
 }
 
@@ -430,16 +431,11 @@ impl CompletionModel {
             model: model.to_string(),
         }
     }
-}
-
-impl completion::CompletionModel for CompletionModel {
-    type Response = CompletionResponse;
 
-    #[cfg_attr(feature = "worker", worker::send)]
-    async fn completion(
+    pub(crate) fn create_completion_request(
         &self,
-        completion_request: completion::CompletionRequest,
-    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
+        completion_request: CompletionRequest,
+    ) -> Result<Value, CompletionError> {
         let prompt = completion_request.prompt_with_context();
 
         let mut messages: Vec<message::Message> =
@@ -468,23 +464,29 @@ impl completion::CompletionModel for CompletionModel {
             "tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
         });
 
+        if let Some(ref params) = completion_request.additional_params {
+            Ok(json_utils::merge(request.clone(), params.clone()))
+        } else {
+            Ok(request)
+        }
+    }
+}
+
+impl completion::CompletionModel for CompletionModel {
+    type Response = CompletionResponse;
+
+    #[cfg_attr(feature = "worker", worker::send)]
+    async fn completion(
+        &self,
+        completion_request: completion::CompletionRequest,
+    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
+        let request = self.create_completion_request(completion_request)?;
         tracing::debug!(
             "Cohere request: {}",
             serde_json::to_string_pretty(&request)?
         );
 
-        let response = self
-            .client
-            .post("/v2/chat")
-            .json(
-                &if let Some(ref params) = completion_request.additional_params {
-                    json_utils::merge(request.clone(), params.clone())
-                } else {
-                    request.clone()
-                },
-            )
-            .send()
-            .await?;
+        let response = self.client.post("/v2/chat").json(&request).send().await?;
 
         if response.status().is_success() {
            let text_response = response.text().await?;
@@ -12,6 +12,7 @@
 pub mod client;
 pub mod completion;
 pub mod embeddings;
+pub mod streaming;
 
 pub use client::Client;
 pub use client::{ApiErrorResponse, ApiResponse};
@@ -23,7 +24,7 @@ pub use embeddings::EmbeddingModel;
 // ================================================================
 
 /// `command-r-plus` completion model
-pub const COMMAND_R_PLUS: &str = "comman-r-plus";
+pub const COMMAND_R_PLUS: &str = "command-r-plus";
 /// `command-r` completion model
 pub const COMMAND_R: &str = "command-r";
 /// `command` completion model
@@ -0,0 +1,195 @@
use crate::completion::{CompletionError, CompletionRequest};
use crate::message::{ToolCall, ToolFunction};
use crate::providers::cohere::completion::{AssistantContent, BilledUnits, Message, Usage};
use crate::providers::cohere::CompletionModel;
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
use crate::{json_utils, streaming};
use async_stream::stream;
use futures::StreamExt;
use serde::Deserialize;
use serde_json::{json, Value};
use std::collections::HashMap;
use std::future::Future;

#[derive(Debug, Deserialize)]
#[serde(rename_all = "kebab-case", tag = "type")]
pub enum StreamingEvent {
    MessageStart,
    ContentStart,
    ContentDelta { delta: Option<Delta> },
    ContentEnd,
    ToolPlan { delta: Option<Delta> },
    ToolCallStart { delta: Option<Delta> },
    ToolCallDelta { delta: Option<Delta> },
    ToolCallEnd { delta: Option<Delta> },
    MessageEnd { delta: Option<MessageEndDelta> },
}

#[derive(Debug, Deserialize)]
struct MessageContentDelta {
    r#type: Option<String>,
    text: Option<String>,
}

#[derive(Debug, Deserialize)]
struct MessageToolFunctionDelta {
    name: Option<String>,
    arguments: Option<String>,
}

#[derive(Debug, Deserialize)]
struct MessageToolCallDelta {
    id: Option<String>,
    r#type: Option<String>,
    function: Option<MessageToolFunctionDelta>,
}

#[derive(Debug, Deserialize)]
struct MessageDelta {
    content: Option<MessageContentDelta>,
    tool_plan: Option<String>,
    tool_calls: Option<MessageToolCallDelta>,
}

#[derive(Debug, Deserialize)]
struct Delta {
    message: Option<MessageDelta>,
}

#[derive(Debug, Deserialize)]
struct MessageEndDelta {
    finish_reason: Option<String>,
    usage: Option<Usage>,
}

#[derive(Clone)]
pub struct StreamingCompletionResponse {
    pub usage: Option<Usage>,
}

impl StreamingCompletionModel for CompletionModel {
    type StreamingResponse = StreamingCompletionResponse;

    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
    {
        let request = self.create_completion_request(request)?;
        let request = json_utils::merge(request, json!({"stream": true}));

        tracing::debug!(
            "Cohere request: {}",
            serde_json::to_string_pretty(&request)?
        );

        let response = self.client.post("/v2/chat").json(&request).send().await?;

        if !response.status().is_success() {
            return Err(CompletionError::ProviderError(format!(
                "{}: {}",
                response.status(),
                response.text().await?
            )));
        }

        let stream = Box::pin(stream! {
            let mut stream = response.bytes_stream();
            let mut current_tool_call: Option<(String, String, String)> = None;

            while let Some(chunk_result) = stream.next().await {
                let chunk = match chunk_result {
                    Ok(c) => c,
                    Err(e) => {
                        yield Err(CompletionError::from(e));
                        break;
                    }
                };

                let text = match String::from_utf8(chunk.to_vec()) {
                    Ok(t) => t,
                    Err(e) => {
                        yield Err(CompletionError::ResponseError(e.to_string()));
                        break;
                    }
                };

                for line in text.lines() {

                    let Some(line) = line.strip_prefix("data: ") else {
                        continue;
                    };

                    let event = {
                        let result = serde_json::from_str::<StreamingEvent>(&line);

                        let Ok(event) = result else {
                            continue;
                        };

                        event
                    };

                    match event {
                        StreamingEvent::ContentDelta { delta: Some(delta) } => {
                            let Some(message) = &delta.message else { continue; };
                            let Some(content) = &message.content else { continue; };
                            let Some(text) = &content.text else { continue; };

                            yield Ok(RawStreamingChoice::Message(text.clone()));
                        },
                        StreamingEvent::MessageEnd {delta: Some(delta)} => {
                            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                                usage: delta.usage.clone()
                            }));
                        },
                        StreamingEvent::ToolCallStart { delta: Some(delta)} => {
                            // Skip the delta if there's any missing information,
                            // though this *should* all be present
                            let Some(message) = &delta.message else { continue; };
                            let Some(tool_calls) = &message.tool_calls else { continue; };
                            let Some(id) = tool_calls.id.clone() else { continue; };
                            let Some(function) = &tool_calls.function else { continue; };
                            let Some(name) = function.name.clone() else { continue; };
                            let Some(arguments) = function.arguments.clone() else { continue; };

                            current_tool_call = Some((id, name, arguments));
                        },
                        StreamingEvent::ToolCallDelta { delta: Some(delta)} => {
                            // Skip the delta if there's any missing information,
                            // though this *should* all be present
                            let Some(message) = &delta.message else { continue; };
                            let Some(tool_calls) = &message.tool_calls else { continue; };
                            let Some(function) = &tool_calls.function else { continue; };
                            let Some(arguments) = function.arguments.clone() else { continue; };

                            if let Some(tc) = current_tool_call.clone() {
                                current_tool_call = Some((
                                    tc.0,
                                    tc.1,
                                    format!("{}{}", tc.2, arguments)
                                ));
                            };
                        },
                        StreamingEvent::ToolCallEnd { .. } => {
                            let Some(tc) = current_tool_call.clone() else { continue; };

                            let Ok(args) = serde_json::from_str(&tc.2) else { continue; };

                            yield Ok(RawStreamingChoice::ToolCall(
                                tc.0,
                                tc.1,
                                args
                            ));

                            current_tool_call = None;
                        },
                        _ => {}
                    };
                }
            }
        });

        Ok(streaming::StreamingCompletionResponse::new(stream))
    }
}
@@ -35,27 +35,6 @@ pub enum RawStreamingChoice<R: Clone> {
     FinalResponse(R),
 }
 
-/// Enum representing a streaming chunk from the model
-#[derive(Debug, Clone)]
-pub enum StreamingChoice {
-    /// A text chunk from a message response
-    Message(String),
-
-    /// A tool call response chunk
-    ToolCall(String, String, serde_json::Value),
-}
-
-impl Display for StreamingChoice {
-    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        match self {
-            StreamingChoice::Message(text) => write!(f, "{}", text),
-            StreamingChoice::ToolCall(name, id, params) => {
-                write!(f, "Tool call: {} {} {:?}", name, id, params)
-            }
-        }
-    }
-}
-
 #[cfg(not(target_arch = "wasm32"))]
 pub type StreamingResult<R> =
     Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;
@@ -85,7 +64,7 @@ impl<R: Clone + Unpin> StreamingCompletionResponse<R> {
 }
 
 impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
-    type Item = Result<StreamingChoice, CompletionError>;
+    type Item = Result<AssistantContent, CompletionError>;
 
     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
         let stream = self.get_mut();
@@ -114,13 +93,13 @@ impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
             Poll::Ready(Some(Ok(choice))) => match choice {
                 RawStreamingChoice::Message(text) => {
                     stream.text = format!("{}{}", stream.text, text.clone());
-                    Poll::Ready(Some(Ok(StreamingChoice::Message(text))))
+                    Poll::Ready(Some(Ok(AssistantContent::text(text))))
                 }
-                RawStreamingChoice::ToolCall(name, description, args) => {
+                RawStreamingChoice::ToolCall(id, name, args) => {
                     stream
                         .tool_calls
-                        .push((name.clone(), description.clone(), args.clone()));
-                    Poll::Ready(Some(Ok(StreamingChoice::ToolCall(name, description, args))))
+                        .push((id.clone(), name.clone(), args.clone()));
+                    Poll::Ready(Some(Ok(AssistantContent::tool_call(id, name, args))))
                 }
                 RawStreamingChoice::FinalResponse(response) => {
                     stream.response = Some(response);
@@ -181,14 +160,17 @@ pub async fn stream_to_stdout<M: StreamingCompletionModel>(
     print!("Response: ");
     while let Some(chunk) = stream.next().await {
         match chunk {
-            Ok(StreamingChoice::Message(text)) => {
-                print!("{}", text);
+            Ok(AssistantContent::Text(text)) => {
+                print!("{}", text.text);
                 std::io::Write::flush(&mut std::io::stdout())?;
             }
-            Ok(StreamingChoice::ToolCall(name, _, params)) => {
+            Ok(AssistantContent::ToolCall(tool_call)) => {
                 let res = agent
                     .tools
-                    .call(&name, params.to_string())
+                    .call(
+                        &tool_call.function.name,
+                        tool_call.function.arguments.to_string(),
+                    )
                     .await
                     .map_err(|e| std::io::Error::other(e.to_string()))?;
                 println!("\nResult: {}", res);