feat: cohere streaming + unify StreamingChoice w/ message

yavens 2025-04-10 19:55:42 -04:00
parent cad584455a
commit 0abd7b8c76
6 changed files with 378 additions and 53 deletions
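
With `StreamingChoice` removed, streaming now yields the same `AssistantContent` items used by regular message responses. A minimal sketch of a downstream consumer under that assumption (module paths and trait bounds are inferred from the hunks below, and `futures::StreamExt` is assumed for `next()`; illustrative only, not part of the commit):

use futures::StreamExt;
use rig::completion::CompletionError;
use rig::message::AssistantContent;
use rig::streaming::{StreamingCompletionModel, StreamingCompletionResponse};

// Hypothetical helper: drain any provider's stream and collect the assistant text.
async fn collect_text<M: StreamingCompletionModel>(
    stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
) -> Result<String, CompletionError> {
    let mut out = String::new();
    while let Some(chunk) = stream.next().await {
        match chunk? {
            // Text deltas are plain message content.
            AssistantContent::Text(text) => out.push_str(&text.text),
            // Tool calls arrive with an id, a name, and parsed JSON arguments.
            AssistantContent::ToolCall(tool_call) => out.push_str(&format!(
                "\n[tool call: {}({})]",
                tool_call.function.name, tool_call.function.arguments
            )),
            // Defensive catch-all in case further content kinds are added.
            _ => {}
        }
    }
    Ok(out)
}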


@@ -0,0 +1,27 @@
use rig::providers::cohere;
use rig::streaming::{stream_to_stdout, StreamingPrompt};
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Create streaming agent with a single context prompt
let agent = cohere::Client::from_env()
.agent(cohere::COMMAND)
.preamble("Be precise and concise.")
.temperature(0.5)
.build();
// Stream the response and print chunks as they arrive
let mut stream = agent
.stream_prompt("When and where and what type is the next solar eclipse?")
.await?;
stream_to_stdout(agent, &mut stream).await?;
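// Once the stream has been fully consumed, the provider's final response (with token usage) is available.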
if let Some(response) = stream.response {
println!("Usage: {:?} tokens", response.usage);
};
println!("Message: {:?}", stream.message);
Ok(())
}


@@ -0,0 +1,118 @@
use anyhow::Result;
use rig::streaming::stream_to_stdout;
use rig::{completion::ToolDefinition, providers, streaming::StreamingPrompt, tool::Tool};
use serde::{Deserialize, Serialize};
use serde_json::json;
#[derive(Deserialize)]
struct OperationArgs {
x: i32,
y: i32,
}
#[derive(Debug, thiserror::Error)]
#[error("Math error")]
struct MathError;
#[derive(Deserialize, Serialize)]
struct Adder;
impl Tool for Adder {
const NAME: &'static str = "add";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
ToolDefinition {
name: "add".to_string(),
description: "Add x and y together".to_string(),
parameters: json!({
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The first number to add"
},
"y": {
"type": "number",
"description": "The second number to add"
}
},
"required": ["x", "y"]
}),
}
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let result = args.x + args.y;
Ok(result)
}
}
#[derive(Deserialize, Serialize)]
struct Subtract;
impl Tool for Subtract {
const NAME: &'static str = "subtract";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
serde_json::from_value(json!({
"name": "subtract",
"description": "Subtract y from x (i.e.: x - y)",
"parameters": {
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The number to subtract from"
},
"y": {
"type": "number",
"description": "The number to subtract"
}
},
"required": ["x", "y"]
}
}))
.expect("Tool Definition")
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let result = args.x - args.y;
Ok(result)
}
}
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
tracing_subscriber::fmt().init();
// Create agent with a single context prompt and two tools
let calculator_agent = providers::cohere::Client::from_env()
.agent(providers::cohere::COMMAND_R)
.preamble(
"You are a calculator here to help the user perform arithmetic
operations. Use the tools provided to answer the user's question.
Make your answer long, so we can test the streaming functionality,
like 20 words",
)
.max_tokens(1024)
.tool(Adder)
.tool(Subtract)
.build();
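// Stream the answer; stream_to_stdout prints text chunks and runs any tool calls the agent makes.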
println!("Calculate 2 - 5");
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
stream_to_stdout(calculator_agent, &mut stream).await?;
if let Some(response) = stream.response {
println!("Usage: {:?} tokens", response.usage);
};
println!("Message: {:?}", stream.message);
Ok(())
}


@@ -6,8 +6,9 @@ use crate::{
};
use super::client::Client;
use crate::completion::CompletionRequest;
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_json::{json, Value};
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
@@ -419,7 +420,7 @@ impl TryFrom<Message> for message::Message {
#[derive(Clone)]
pub struct CompletionModel {
client: Client,
pub(crate) client: Client,
pub model: String,
}
@@ -430,16 +431,11 @@ impl CompletionModel {
model: model.to_string(),
}
}
}
impl completion::CompletionModel for CompletionModel {
type Response = CompletionResponse;
#[cfg_attr(feature = "worker", worker::send)]
async fn completion(
pub(crate) fn create_completion_request(
&self,
completion_request: completion::CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
completion_request: CompletionRequest,
) -> Result<Value, CompletionError> {
let prompt = completion_request.prompt_with_context();
let mut messages: Vec<message::Message> =
@@ -468,23 +464,29 @@ impl completion::CompletionModel for CompletionModel {
"tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
});
if let Some(ref params) = completion_request.additional_params {
Ok(json_utils::merge(request.clone(), params.clone()))
} else {
Ok(request)
}
}
}
impl completion::CompletionModel for CompletionModel {
type Response = CompletionResponse;
#[cfg_attr(feature = "worker", worker::send)]
async fn completion(
&self,
completion_request: completion::CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
let request = self.create_completion_request(completion_request)?;
tracing::debug!(
"Cohere request: {}",
serde_json::to_string_pretty(&request)?
);
let response = self
.client
.post("/v2/chat")
.json(
&if let Some(ref params) = completion_request.additional_params {
json_utils::merge(request.clone(), params.clone())
} else {
request.clone()
},
)
.send()
.await?;
let response = self.client.post("/v2/chat").json(&request).send().await?;
if response.status().is_success() {
let text_response = response.text().await?;


@@ -12,6 +12,7 @@
pub mod client;
pub mod completion;
pub mod embeddings;
pub mod streaming;
pub use client::Client;
pub use client::{ApiErrorResponse, ApiResponse};
@@ -23,7 +24,7 @@ pub use embeddings::EmbeddingModel;
// ================================================================
/// `command-r-plus` completion model
pub const COMMAND_R_PLUS: &str = "comman-r-plus";
pub const COMMAND_R_PLUS: &str = "command-r-plus";
/// `command-r` completion model
pub const COMMAND_R: &str = "command-r";
/// `command` completion model


@@ -0,0 +1,195 @@
use crate::completion::{CompletionError, CompletionRequest};
use crate::message::{ToolCall, ToolFunction};
use crate::providers::cohere::completion::{AssistantContent, BilledUnits, Message, Usage};
use crate::providers::cohere::CompletionModel;
use crate::streaming::{RawStreamingChoice, StreamingCompletionModel};
use crate::{json_utils, streaming};
use async_stream::stream;
use futures::StreamExt;
use serde::Deserialize;
use serde_json::{json, Value};
use std::collections::HashMap;
use std::future::Future;
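/// Server-sent events emitted by Cohere's v2 chat streaming endpoint, discriminated by their "type" field.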
#[derive(Debug, Deserialize)]
#[serde(rename_all = "kebab-case", tag = "type")]
pub enum StreamingEvent {
MessageStart,
ContentStart,
ContentDelta { delta: Option<Delta> },
ContentEnd,
ToolPlan { delta: Option<Delta> },
ToolCallStart { delta: Option<Delta> },
ToolCallDelta { delta: Option<Delta> },
ToolCallEnd { delta: Option<Delta> },
MessageEnd { delta: Option<MessageEndDelta> },
}
#[derive(Debug, Deserialize)]
struct MessageContentDelta {
r#type: Option<String>,
text: Option<String>,
}
#[derive(Debug, Deserialize)]
struct MessageToolFunctionDelta {
name: Option<String>,
arguments: Option<String>,
}
#[derive(Debug, Deserialize)]
struct MessageToolCallDelta {
id: Option<String>,
r#type: Option<String>,
function: Option<MessageToolFunctionDelta>,
}
#[derive(Debug, Deserialize)]
struct MessageDelta {
content: Option<MessageContentDelta>,
tool_plan: Option<String>,
tool_calls: Option<MessageToolCallDelta>,
}
#[derive(Debug, Deserialize)]
struct Delta {
message: Option<MessageDelta>,
}
#[derive(Debug, Deserialize)]
struct MessageEndDelta {
finish_reason: Option<String>,
usage: Option<Usage>,
}
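/// Final response data surfaced once the stream completes: token usage reported on the message-end event.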
#[derive(Clone)]
pub struct StreamingCompletionResponse {
pub usage: Option<Usage>,
}
impl StreamingCompletionModel for CompletionModel {
type StreamingResponse = StreamingCompletionResponse;
async fn stream(
&self,
request: CompletionRequest,
) -> Result<streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
{
let request = self.create_completion_request(request)?;
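// Build the same request body as the non-streaming path, then enable streaming.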
let request = json_utils::merge(request, json!({"stream": true}));
tracing::debug!(
"Cohere request: {}",
serde_json::to_string_pretty(&request)?
);
let response = self.client.post("/v2/chat").json(&request).send().await?;
if !response.status().is_success() {
return Err(CompletionError::ProviderError(format!(
"{}: {}",
response.status(),
response.text().await?
)));
}
let stream = Box::pin(stream! {
let mut stream = response.bytes_stream();
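// Partially assembled tool call: (id, name, accumulated JSON argument text).
// The id and name arrive on tool-call-start, arguments accumulate over tool-call-delta, and the call is emitted on tool-call-end.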
let mut current_tool_call: Option<(String, String, String)> = None;
while let Some(chunk_result) = stream.next().await {
let chunk = match chunk_result {
Ok(c) => c,
Err(e) => {
yield Err(CompletionError::from(e));
break;
}
};
let text = match String::from_utf8(chunk.to_vec()) {
Ok(t) => t,
Err(e) => {
yield Err(CompletionError::ResponseError(e.to_string()));
break;
}
};
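// Cohere streams SSE lines of the form "data: <json>"; other lines, and events that fail to parse, are skipped.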
for line in text.lines() {
let Some(line) = line.strip_prefix("data: ") else {
continue;
};
let event = {
let result = serde_json::from_str::<StreamingEvent>(&line);
let Ok(event) = result else {
continue;
};
event
};
match event {
StreamingEvent::ContentDelta { delta: Some(delta) } => {
let Some(message) = &delta.message else { continue; };
let Some(content) = &message.content else { continue; };
let Some(text) = &content.text else { continue; };
yield Ok(RawStreamingChoice::Message(text.clone()));
},
StreamingEvent::MessageEnd {delta: Some(delta)} => {
yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
usage: delta.usage.clone()
}));
},
StreamingEvent::ToolCallStart { delta: Some(delta)} => {
// Skip the delta if there's any missing information,
// though this *should* all be present
let Some(message) = &delta.message else { continue; };
let Some(tool_calls) = &message.tool_calls else { continue; };
let Some(id) = tool_calls.id.clone() else { continue; };
let Some(function) = &tool_calls.function else { continue; };
let Some(name) = function.name.clone() else { continue; };
let Some(arguments) = function.arguments.clone() else { continue; };
current_tool_call = Some((id, name, arguments));
},
StreamingEvent::ToolCallDelta { delta: Some(delta)} => {
// Skip the delta if there's any missing information,
// though this *should* all be present
let Some(message) = &delta.message else { continue; };
let Some(tool_calls) = &message.tool_calls else { continue; };
let Some(function) = &tool_calls.function else { continue; };
let Some(arguments) = function.arguments.clone() else { continue; };
if let Some(tc) = current_tool_call.clone() {
current_tool_call = Some((
tc.0,
tc.1,
format!("{}{}", tc.2, arguments)
));
};
},
StreamingEvent::ToolCallEnd { .. } => {
let Some(tc) = current_tool_call.clone() else { continue; };
let Ok(args) = serde_json::from_str(&tc.2) else { continue; };
yield Ok(RawStreamingChoice::ToolCall(
tc.0,
tc.1,
args
));
current_tool_call = None;
},
_ => {}
};
}
}
});
Ok(streaming::StreamingCompletionResponse::new(stream))
}
}
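
For reference, a minimal sketch of the event shape this deserializer expects, written as a unit test against the structs above. The JSON literal is illustrative only: field names are inferred from the Rust types, and the live Cohere payload may carry additional fields.

#[cfg(test)]
mod tests {
    use super::StreamingEvent;

    #[test]
    fn parses_an_illustrative_content_delta_event() {
        // Hand-written payload shaped to match the structs above; not captured from the live API.
        let line = r#"{"type":"content-delta","delta":{"message":{"content":{"type":"text","text":"Hello"}}}}"#;
        let event: StreamingEvent =
            serde_json::from_str(line).expect("illustrative event should deserialize");
        assert!(matches!(event, StreamingEvent::ContentDelta { delta: Some(_) }));
    }
}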


@@ -35,27 +35,6 @@ pub enum RawStreamingChoice<R: Clone> {
FinalResponse(R),
}
/// Enum representing a streaming chunk from the model
#[derive(Debug, Clone)]
pub enum StreamingChoice {
/// A text chunk from a message response
Message(String),
/// A tool call response chunk
ToolCall(String, String, serde_json::Value),
}
impl Display for StreamingChoice {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
StreamingChoice::Message(text) => write!(f, "{}", text),
StreamingChoice::ToolCall(name, id, params) => {
write!(f, "Tool call: {} {} {:?}", name, id, params)
}
}
}
}
#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult<R> =
Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;
@@ -85,7 +64,7 @@ impl<R: Clone + Unpin> StreamingCompletionResponse<R> {
}
impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
type Item = Result<StreamingChoice, CompletionError>;
type Item = Result<AssistantContent, CompletionError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let stream = self.get_mut();
@@ -114,13 +93,13 @@ impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
Poll::Ready(Some(Ok(choice))) => match choice {
RawStreamingChoice::Message(text) => {
stream.text = format!("{}{}", stream.text, text.clone());
Poll::Ready(Some(Ok(StreamingChoice::Message(text))))
Poll::Ready(Some(Ok(AssistantContent::text(text))))
}
RawStreamingChoice::ToolCall(name, description, args) => {
RawStreamingChoice::ToolCall(id, name, args) => {
stream
.tool_calls
.push((name.clone(), description.clone(), args.clone()));
Poll::Ready(Some(Ok(StreamingChoice::ToolCall(name, description, args))))
.push((id.clone(), name.clone(), args.clone()));
Poll::Ready(Some(Ok(AssistantContent::tool_call(id, name, args))))
}
RawStreamingChoice::FinalResponse(response) => {
stream.response = Some(response);
@@ -181,14 +160,17 @@ pub async fn stream_to_stdout<M: StreamingCompletionModel>(
print!("Response: ");
while let Some(chunk) = stream.next().await {
match chunk {
Ok(StreamingChoice::Message(text)) => {
print!("{}", text);
Ok(AssistantContent::Text(text)) => {
print!("{}", text.text);
std::io::Write::flush(&mut std::io::stdout())?;
}
Ok(StreamingChoice::ToolCall(name, _, params)) => {
Ok(AssistantContent::ToolCall(tool_call)) => {
let res = agent
.tools
.call(&name, params.to_string())
.call(
&tool_call.function.name,
tool_call.function.arguments.to_string(),
)
.await
.map_err(|e| std::io::Error::other(e.to_string()))?;
println!("\nResult: {}", res);