feat(provider): cohere-v2 (#350)

* fix(cohere): partial v2

* feat(cohere): partial

* feat: cohere v2 impl

* fix(cohere): working cohere impl
This commit is contained in:
Mochan 2025-03-17 17:13:54 -07:00 committed by GitHub
parent 19c095122f
commit 599f8b46c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 1054 additions and 673 deletions

View File

@ -0,0 +1,125 @@
use rig::{
completion::{Prompt, ToolDefinition},
providers,
tool::Tool,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.with_target(false)
.init();
let client = providers::cohere::Client::from_env();
let agent = client
.agent("command-r")
.preamble("You are a helpful assistant.")
.build();
let answer = agent.prompt("Tell me a joke").await?;
println!("Answer: {}", answer);
// Create agent with a single context prompt and two tools
let calculator_agent = client
.agent(providers::cohere::COMMAND_R)
.preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
.max_tokens(1024)
.tool(Adder)
.tool(Subtract)
.build();
// Prompt the agent and print the response
println!("Calculate 2 - 5");
println!(
"Cohere Calculator Agent: {}",
calculator_agent.prompt("Calculate 2 - 5").await?
);
Ok(())
}
#[derive(Deserialize)]
struct OperationArgs {
x: i32,
y: i32,
}
#[derive(Debug, thiserror::Error)]
#[error("Math error")]
struct MathError;
#[derive(Deserialize, Serialize)]
struct Adder;
impl Tool for Adder {
const NAME: &'static str = "add";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
ToolDefinition {
name: "add".to_string(),
description: "Add x and y together".to_string(),
parameters: json!({
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The first number to add"
},
"y": {
"type": "number",
"description": "The second number to add"
}
}
}),
}
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
println!("[tool-call] Adding {} and {}", args.x, args.y);
let result = args.x + args.y;
Ok(result)
}
}
#[derive(Deserialize, Serialize)]
struct Subtract;
impl Tool for Subtract {
const NAME: &'static str = "subtract";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
serde_json::from_value(json!({
"name": "subtract",
"description": "Subtract y from x (i.e.: x - y)",
"parameters": {
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The number to subtract from"
},
"y": {
"type": "number",
"description": "The number to subtract"
}
}
}
}))
.expect("Tool Definition")
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
println!("[tool-call] Subtracting {} from {}", args.y, args.x);
let result = args.x - args.y;
Ok(result)
}
}

View File

@ -1,50 +0,0 @@
use std::env;
use rig::{
completion::{Completion, Prompt},
providers::cohere::Client as CohereClient,
};
use serde_json::json;
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Create Cohere client
let cohere_api_key = env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
let cohere_client = CohereClient::new(&cohere_api_key);
let klimadao_agent = cohere_client
.agent("command-r")
.temperature(0.0)
.additional_params(json!({
"connectors": [{"id":"web-search", "options":{"site": "https://docs.klimadao.finance"}}]
}))
.build();
// Prompt the model and print the response
// We use `prompt` to get a simple response from the model as a String
let response = klimadao_agent.prompt("Tell me about BCT tokens?").await?;
println!("\n\nCoral: {:?}", response);
// Prompt the model and get the citations
// We use `completion` to allow use to customize the request further and
// get a more detailed response from the model.
// Here the response is of type CompletionResponse<cohere::CompletionResponse>
// which contains `choice` (Message or ToolCall) as well as `raw_response`,
// the underlying providers' raw response.
let response = klimadao_agent
.completion("Tell me about BCT tokens?", vec![])
.await?
.additional_params(json!({
"connectors": [{"id":"web-search", "options":{"site": "https://docs.klimadao.finance"}}]
}))
.send()
.await?;
println!(
"\n\nCoral: {:?}\n\nCitations:\n{:?}",
response.choice, response.raw_response.citations
);
Ok(())
}

View File

@ -15,6 +15,11 @@ struct Debater {
impl Debater {
fn new(position_a: &str, position_b: &str) -> Self {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
.with_target(false)
.init();
let openai_client =
openai::Client::new(&env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"));
let cohere_client =

View File

@ -1,623 +0,0 @@
//! Cohere API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::cohere;
//!
//! let client = cohere::Client::new("YOUR_API_KEY");
//!
//! let command_r = client.completion_model(cohere::COMMAND_R);
//! ```
use std::collections::HashMap;
use crate::{
agent::AgentBuilder,
completion::{self, CompletionError},
embeddings::{self, EmbeddingError, EmbeddingsBuilder},
extractor::ExtractorBuilder,
json_utils, message, Embed, OneOrMany,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
// ================================================================
// Main Cohere Client
// ================================================================
const COHERE_API_BASE_URL: &str = "https://api.cohere.ai";
#[derive(Clone)]
pub struct Client {
base_url: String,
http_client: reqwest::Client,
}
impl Client {
pub fn new(api_key: &str) -> Self {
Self::from_url(api_key, COHERE_API_BASE_URL)
}
pub fn from_url(api_key: &str, base_url: &str) -> Self {
Self {
base_url: base_url.to_string(),
http_client: reqwest::Client::builder()
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"Authorization",
format!("Bearer {}", api_key)
.parse()
.expect("Bearer token should parse"),
);
headers
})
.build()
.expect("Cohere reqwest client should build"),
}
}
/// Create a new Cohere client from the `COHERE_API_KEY` environment variable.
/// Panics if the environment variable is not set.
pub fn from_env() -> Self {
let api_key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
Self::new(&api_key)
}
pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
self.http_client.post(url)
}
/// Note: default embedding dimension of 0 will be used if model is not known.
/// If this is the case, it's better to use function `embedding_model_with_ndims`
pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel {
let ndims = match model {
EMBED_ENGLISH_V3 | EMBED_MULTILINGUAL_V3 | EMBED_ENGLISH_LIGHT_V2 => 1024,
EMBED_ENGLISH_LIGHT_V3 | EMBED_MULTILINGUAL_LIGHT_V3 => 384,
EMBED_ENGLISH_V2 => 4096,
EMBED_MULTILINGUAL_V2 => 768,
_ => 0,
};
EmbeddingModel::new(self.clone(), model, input_type, ndims)
}
/// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
pub fn embedding_model_with_ndims(
&self,
model: &str,
input_type: &str,
ndims: usize,
) -> EmbeddingModel {
EmbeddingModel::new(self.clone(), model, input_type, ndims)
}
pub fn embeddings<D: Embed>(
&self,
model: &str,
input_type: &str,
) -> EmbeddingsBuilder<EmbeddingModel, D> {
EmbeddingsBuilder::new(self.embedding_model(model, input_type))
}
pub fn completion_model(&self, model: &str) -> CompletionModel {
CompletionModel::new(self.clone(), model)
}
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
AgentBuilder::new(self.completion_model(model))
}
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
&self,
model: &str,
) -> ExtractorBuilder<T, CompletionModel> {
ExtractorBuilder::new(self.completion_model(model))
}
}
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
message: String,
}
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
Ok(T),
Err(ApiErrorResponse),
}
// ================================================================
// Cohere Embedding API
// ================================================================
/// `embed-english-v3.0` embedding model
pub const EMBED_ENGLISH_V3: &str = "embed-english-v3.0";
/// `embed-english-light-v3.0` embedding model
pub const EMBED_ENGLISH_LIGHT_V3: &str = "embed-english-light-v3.0";
/// `embed-multilingual-v3.0` embedding model
pub const EMBED_MULTILINGUAL_V3: &str = "embed-multilingual-v3.0";
/// `embed-multilingual-light-v3.0` embedding model
pub const EMBED_MULTILINGUAL_LIGHT_V3: &str = "embed-multilingual-light-v3.0";
/// `embed-english-v2.0` embedding model
pub const EMBED_ENGLISH_V2: &str = "embed-english-v2.0";
/// `embed-english-light-v2.0` embedding model
pub const EMBED_ENGLISH_LIGHT_V2: &str = "embed-english-light-v2.0";
/// `embed-multilingual-v2.0` embedding model
pub const EMBED_MULTILINGUAL_V2: &str = "embed-multilingual-v2.0";
#[derive(Deserialize)]
pub struct EmbeddingResponse {
#[serde(default)]
pub response_type: Option<String>,
pub id: String,
pub embeddings: Vec<Vec<f64>>,
pub texts: Vec<String>,
#[serde(default)]
pub meta: Option<Meta>,
}
#[derive(Deserialize)]
pub struct Meta {
pub api_version: ApiVersion,
pub billed_units: BilledUnits,
#[serde(default)]
pub warnings: Vec<String>,
}
#[derive(Deserialize)]
pub struct ApiVersion {
pub version: String,
#[serde(default)]
pub is_deprecated: Option<bool>,
#[serde(default)]
pub is_experimental: Option<bool>,
}
#[derive(Deserialize, Debug)]
pub struct BilledUnits {
#[serde(default)]
pub input_tokens: u32,
#[serde(default)]
pub output_tokens: u32,
#[serde(default)]
pub search_units: u32,
#[serde(default)]
pub classifications: u32,
}
impl std::fmt::Display for BilledUnits {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Input tokens: {}\nOutput tokens: {}\nSearch units: {}\nClassifications: {}",
self.input_tokens, self.output_tokens, self.search_units, self.classifications
)
}
}
#[derive(Clone)]
pub struct EmbeddingModel {
client: Client,
pub model: String,
pub input_type: String,
ndims: usize,
}
impl embeddings::EmbeddingModel for EmbeddingModel {
const MAX_DOCUMENTS: usize = 96;
fn ndims(&self) -> usize {
self.ndims
}
#[cfg_attr(feature = "worker", worker::send)]
async fn embed_texts(
&self,
documents: impl IntoIterator<Item = String>,
) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
let documents = documents.into_iter().collect::<Vec<_>>();
let response = self
.client
.post("/v1/embed")
.json(&json!({
"model": self.model,
"texts": documents,
"input_type": self.input_type,
}))
.send()
.await?;
if response.status().is_success() {
match response.json::<ApiResponse<EmbeddingResponse>>().await? {
ApiResponse::Ok(response) => {
match response.meta {
Some(meta) => tracing::info!(target: "rig",
"Cohere embeddings billed units: {}",
meta.billed_units,
),
None => tracing::info!(target: "rig",
"Cohere embeddings billed units: n/a",
),
};
if response.embeddings.len() != documents.len() {
return Err(EmbeddingError::DocumentError(
format!(
"Expected {} embeddings, got {}",
documents.len(),
response.embeddings.len()
)
.into(),
));
}
Ok(response
.embeddings
.into_iter()
.zip(documents.into_iter())
.map(|(embedding, document)| embeddings::Embedding {
document,
vec: embedding,
})
.collect())
}
ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)),
}
} else {
Err(EmbeddingError::ProviderError(response.text().await?))
}
}
}
impl EmbeddingModel {
pub fn new(client: Client, model: &str, input_type: &str, ndims: usize) -> Self {
Self {
client,
model: model.to_string(),
input_type: input_type.to_string(),
ndims,
}
}
}
// ================================================================
// Cohere Completion API
// ================================================================
/// `command-r-plus` completion model
pub const COMMAND_R_PLUS: &str = "comman-r-plus";
/// `command-r` completion model
pub const COMMAND_R: &str = "command-r";
/// `command` completion model
pub const COMMAND: &str = "command";
/// `command-nightly` completion model
pub const COMMAND_NIGHTLY: &str = "command-nightly";
/// `command-light` completion model
pub const COMMAND_LIGHT: &str = "command-light";
/// `command-light-nightly` completion model
pub const COMMAND_LIGHT_NIGHTLY: &str = "command-light-nightly";
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
pub text: String,
pub generation_id: String,
#[serde(default)]
pub citations: Vec<Citation>,
#[serde(default)]
pub documents: Vec<Document>,
#[serde(default)]
pub is_search_required: Option<bool>,
#[serde(default)]
pub search_queries: Vec<SearchQuery>,
#[serde(default)]
pub search_results: Vec<SearchResult>,
pub finish_reason: String,
#[serde(default)]
pub tool_calls: Vec<ToolCall>,
#[serde(default)]
pub chat_history: Vec<ChatHistory>,
}
impl From<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
fn from(response: CompletionResponse) -> Self {
let CompletionResponse {
text, tool_calls, ..
} = &response;
let model_response = if !tool_calls.is_empty() {
tool_calls
.iter()
.map(|tool_call| {
completion::AssistantContent::tool_call(
tool_call.name.clone(),
tool_call.name.clone(),
tool_call.parameters.clone(),
)
})
.collect::<Vec<_>>()
} else {
vec![completion::AssistantContent::text(text.clone())]
};
completion::CompletionResponse {
choice: OneOrMany::many(model_response).expect("There is atleast one content"),
raw_response: response,
}
}
}
#[derive(Debug, Deserialize)]
pub struct Citation {
pub start: u32,
pub end: u32,
pub text: String,
pub document_ids: Vec<String>,
}
#[derive(Debug, Deserialize)]
pub struct Document {
pub id: String,
#[serde(flatten)]
pub additional_prop: HashMap<String, serde_json::Value>,
}
#[derive(Debug, Deserialize)]
pub struct SearchQuery {
pub text: String,
pub generation_id: String,
}
#[derive(Debug, Deserialize)]
pub struct SearchResult {
pub search_query: SearchQuery,
pub connector: Connector,
pub document_ids: Vec<String>,
#[serde(default)]
pub error_message: Option<String>,
#[serde(default)]
pub continue_on_failure: bool,
}
#[derive(Debug, Deserialize)]
pub struct Connector {
pub id: String,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolCall {
pub name: String,
pub parameters: serde_json::Value,
}
#[derive(Debug, Deserialize)]
pub struct ChatHistory {
pub role: String,
pub message: String,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct Parameter {
pub description: String,
pub r#type: String,
pub required: bool,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
pub name: String,
pub description: String,
pub parameter_definitions: HashMap<String, Parameter>,
}
impl From<completion::ToolDefinition> for ToolDefinition {
fn from(tool: completion::ToolDefinition) -> Self {
fn convert_type(r#type: &serde_json::Value) -> String {
fn convert_type_str(r#type: &str) -> String {
match r#type {
"string" => "string".to_owned(),
"number" => "number".to_owned(),
"integer" => "integer".to_owned(),
"boolean" => "boolean".to_owned(),
"array" => "array".to_owned(),
"object" => "object".to_owned(),
_ => "string".to_owned(),
}
}
match r#type {
serde_json::Value::String(r#type) => convert_type_str(r#type.as_str()),
serde_json::Value::Array(types) => convert_type_str(
types
.iter()
.find(|t| t.as_str() != Some("null"))
.and_then(|t| t.as_str())
.unwrap_or("string"),
),
_ => "string".to_owned(),
}
}
let maybe_required = tool
.parameters
.get("required")
.and_then(|v| v.as_array())
.map(|required| {
required
.iter()
.filter_map(|v| v.as_str())
.collect::<Vec<_>>()
})
.unwrap_or_default();
Self {
name: tool.name,
description: tool.description,
parameter_definitions: tool
.parameters
.get("properties")
.expect("Tool properties should exist")
.as_object()
.expect("Tool properties should be an object")
.iter()
.map(|(argname, argdef)| {
(
argname.clone(),
Parameter {
description: argdef
.get("description")
.expect("Argument description should exist")
.as_str()
.expect("Argument description should be a string")
.to_string(),
r#type: convert_type(
argdef.get("type").expect("Argument type should exist"),
),
required: maybe_required.contains(&argname.as_str()),
},
)
})
.collect::<HashMap<_, _>>(),
}
}
}
#[derive(Deserialize, Serialize)]
#[serde(tag = "role", rename_all = "UPPERCASE")]
pub enum Message {
User {
message: String,
tool_calls: Vec<ToolCall>,
},
Chatbot {
message: String,
tool_calls: Vec<ToolCall>,
},
Tool {
tool_results: Vec<ToolResult>,
},
/// According to the documentation, this message type should not be used
System {
content: String,
tool_calls: Vec<ToolCall>,
},
}
#[derive(Deserialize, Serialize)]
pub struct ToolResult {
pub call: ToolCall,
pub outputs: Vec<serde_json::Value>,
}
impl TryFrom<message::Message> for Vec<Message> {
type Error = message::MessageError;
fn try_from(message: message::Message) -> Result<Self, Self::Error> {
match message {
message::Message::User { content } => content
.into_iter()
.map(|content| {
Ok(Message::User {
message: match content {
message::UserContent::Text(message::Text { text }) => text,
_ => {
return Err(message::MessageError::ConversionError(
"Only text content is supported by Cohere".to_owned(),
))
}
},
tool_calls: vec![],
})
})
.collect::<Result<Vec<_>, _>>(),
_ => Err(message::MessageError::ConversionError(
"Only user messages are supported by Cohere".to_owned(),
)),
}
}
}
#[derive(Clone)]
pub struct CompletionModel {
client: Client,
pub model: String,
}
impl CompletionModel {
pub fn new(client: Client, model: &str) -> Self {
Self {
client,
model: model.to_string(),
}
}
}
impl completion::CompletionModel for CompletionModel {
type Response = CompletionResponse;
#[cfg_attr(feature = "worker", worker::send)]
async fn completion(
&self,
completion_request: completion::CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
let chat_history = completion_request
.chat_history
.into_iter()
.map(Vec::<Message>::try_from)
.collect::<Result<Vec<Vec<_>>, _>>()?
.into_iter()
.flatten()
.collect::<Vec<_>>();
let message = match completion_request.prompt {
message::Message::User { content } => Ok(content
.into_iter()
.map(|content| match content {
message::UserContent::Text(message::Text { text }) => Ok(text),
_ => Err(CompletionError::RequestError(
"Only text content is supported by Cohere".into(),
)),
})
.collect::<Result<Vec<_>, _>>()?
.join("\n")),
_ => Err(CompletionError::RequestError(
"Only user messages are supported by Cohere".into(),
)),
}?;
let request = json!({
"model": self.model,
"preamble": completion_request.preamble,
"message": message,
"documents": completion_request.documents,
"chat_history": chat_history,
"temperature": completion_request.temperature,
"tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
});
let response = self
.client
.post("/v1/chat")
.json(
&if let Some(ref params) = completion_request.additional_params {
json_utils::merge(request.clone(), params.clone())
} else {
request.clone()
},
)
.send()
.await?;
if response.status().is_success() {
match response.json::<ApiResponse<CompletionResponse>>().await? {
ApiResponse::Ok(completion) => Ok(completion.into()),
ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
}
} else {
Err(CompletionError::ProviderError(response.text().await?))
}
}
}

View File

@ -0,0 +1,116 @@
use crate::{
agent::AgentBuilder, embeddings::EmbeddingsBuilder, extractor::ExtractorBuilder, Embed,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use super::{CompletionModel, EmbeddingModel};
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
pub message: String,
}
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
Ok(T),
Err(ApiErrorResponse),
}
// ================================================================
// Main Cohere Client
// ================================================================
const COHERE_API_BASE_URL: &str = "https://api.cohere.ai";
#[derive(Clone)]
pub struct Client {
base_url: String,
http_client: reqwest::Client,
}
impl Client {
pub fn new(api_key: &str) -> Self {
Self::from_url(api_key, COHERE_API_BASE_URL)
}
pub fn from_url(api_key: &str, base_url: &str) -> Self {
Self {
base_url: base_url.to_string(),
http_client: reqwest::Client::builder()
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"Authorization",
format!("Bearer {}", api_key)
.parse()
.expect("Bearer token should parse"),
);
headers
})
.build()
.expect("Cohere reqwest client should build"),
}
}
/// Create a new Cohere client from the `COHERE_API_KEY` environment variable.
/// Panics if the environment variable is not set.
pub fn from_env() -> Self {
let api_key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
Self::new(&api_key)
}
pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
self.http_client.post(url)
}
/// Note: default embedding dimension of 0 will be used if model is not known.
/// If this is the case, it's better to use function `embedding_model_with_ndims`
pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel {
let ndims = match model {
super::EMBED_ENGLISH_V3
| super::EMBED_MULTILINGUAL_V3
| super::EMBED_ENGLISH_LIGHT_V2 => 1024,
super::EMBED_ENGLISH_LIGHT_V3 | super::EMBED_MULTILINGUAL_LIGHT_V3 => 384,
super::EMBED_ENGLISH_V2 => 4096,
super::EMBED_MULTILINGUAL_V2 => 768,
_ => 0,
};
EmbeddingModel::new(self.clone(), model, input_type, ndims)
}
/// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
pub fn embedding_model_with_ndims(
&self,
model: &str,
input_type: &str,
ndims: usize,
) -> EmbeddingModel {
EmbeddingModel::new(self.clone(), model, input_type, ndims)
}
pub fn embeddings<D: Embed>(
&self,
model: &str,
input_type: &str,
) -> EmbeddingsBuilder<EmbeddingModel, D> {
EmbeddingsBuilder::new(self.embedding_model(model, input_type))
}
pub fn completion_model(&self, model: &str) -> CompletionModel {
CompletionModel::new(self.clone(), model)
}
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
AgentBuilder::new(self.completion_model(model))
}
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
&self,
model: &str,
) -> ExtractorBuilder<T, CompletionModel> {
ExtractorBuilder::new(self.completion_model(model))
}
}

View File

@ -0,0 +1,611 @@
use std::collections::HashMap;
use crate::{
completion::{self, CompletionError},
json_utils, message, OneOrMany,
};
use super::client::Client;
use serde::{Deserialize, Serialize};
use serde_json::json;
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
pub id: String,
pub finish_reason: FinishReason,
message: Message,
#[serde(default)]
pub usage: Option<Usage>,
}
impl CompletionResponse {
/// Return that parts of the response for assistant messages w/o dealing with the other variants
pub fn message(&self) -> (Vec<AssistantContent>, Vec<Citation>, Vec<ToolCall>) {
let Message::Assistant {
content,
citations,
tool_calls,
..
} = self.message.clone()
else {
unreachable!("Completion responses will only return an assistant message")
};
(content, citations, tool_calls)
}
}
#[derive(Debug, Deserialize, PartialEq, Eq, Clone)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FinishReason {
MaxTokens,
StopSequence,
Complete,
Error,
ToolCall,
}
#[derive(Debug, Deserialize, Clone)]
pub struct Usage {
#[serde(default)]
pub billed_units: Option<BilledUnits>,
#[serde(default)]
pub tokens: Option<Tokens>,
}
#[derive(Debug, Deserialize, Clone)]
pub struct BilledUnits {
#[serde(default)]
pub output_tokens: Option<f64>,
#[serde(default)]
pub classifications: Option<f64>,
#[serde(default)]
pub search_units: Option<f64>,
#[serde(default)]
pub input_tokens: Option<f64>,
}
#[derive(Debug, Deserialize, Clone)]
pub struct Tokens {
#[serde(default)]
pub input_tokens: Option<f64>,
#[serde(default)]
pub output_tokens: Option<f64>,
}
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
type Error = CompletionError;
fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
let (content, _, tool_calls) = response.message();
let model_response = if !tool_calls.is_empty() {
OneOrMany::many(
tool_calls
.into_iter()
.filter_map(|tool_call| {
let ToolCallFunction { name, arguments } = tool_call.function?;
let id = tool_call.id.unwrap_or_else(|| name.clone());
Some(completion::AssistantContent::tool_call(id, name, arguments))
})
.collect::<Vec<_>>(),
)
.expect("We have atleast 1 tool call in this if block")
} else {
OneOrMany::many(content.into_iter().map(|content| match content {
AssistantContent::Text { text } => completion::AssistantContent::text(text),
}))
.map_err(|_| {
CompletionError::ResponseError(
"Response contained no message or tool call (empty)".to_owned(),
)
})?
};
Ok(completion::CompletionResponse {
choice: OneOrMany::many(model_response).expect("There is atleast one content"),
raw_response: response,
})
}
}
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Document {
pub id: String,
pub data: HashMap<String, serde_json::Value>,
}
impl From<completion::Document> for Document {
fn from(document: completion::Document) -> Self {
let mut data: HashMap<String, serde_json::Value> = HashMap::new();
// We use `.into()` here explicitly since the `document.additional_props` type will likely
// evolve into `serde_json::Value` in the future.
document
.additional_props
.into_iter()
.for_each(|(key, value)| {
data.insert(key, value.into());
});
data.insert("text".to_string(), document.text.into());
Self {
id: document.id,
data,
}
}
}
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct ToolCall {
#[serde(default)]
pub id: Option<String>,
#[serde(default)]
pub r#type: Option<ToolType>,
#[serde(default)]
pub function: Option<ToolCallFunction>,
}
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct ToolCallFunction {
pub name: String,
#[serde(with = "json_utils::stringified_json")]
pub arguments: serde_json::Value,
}
#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
#[default]
Function,
}
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Tool {
pub r#type: ToolType,
pub function: Function,
}
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Function {
pub name: String,
#[serde(default)]
pub description: Option<String>,
pub parameters: serde_json::Value,
}
impl From<completion::ToolDefinition> for Tool {
fn from(tool: completion::ToolDefinition) -> Self {
Self {
r#type: ToolType::default(),
function: Function {
name: tool.name,
description: Some(tool.description),
parameters: tool.parameters,
},
}
}
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
User {
content: OneOrMany<UserContent>,
},
Assistant {
#[serde(default)]
content: Vec<AssistantContent>,
#[serde(default)]
citations: Vec<Citation>,
#[serde(default)]
tool_calls: Vec<ToolCall>,
#[serde(default)]
tool_plan: Option<String>,
},
Tool {
content: OneOrMany<ToolResultContent>,
tool_call_id: String,
},
System {
content: String,
},
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
Text { text: String },
ImageUrl { image_url: ImageUrl },
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
Text { text: String },
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct ImageUrl {
pub url: String,
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub enum ToolResultContent {
Text { text: String },
Document { document: Document },
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Citation {
#[serde(default)]
pub start: Option<u32>,
#[serde(default)]
pub end: Option<u32>,
#[serde(default)]
pub text: Option<String>,
#[serde(rename = "type")]
pub citation_type: Option<CitationType>,
#[serde(default)]
pub sources: Vec<Source>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum Source {
Document {
id: Option<String>,
document: Option<serde_json::Map<String, serde_json::Value>>,
},
Tool {
id: Option<String>,
tool_output: Option<serde_json::Map<String, serde_json::Value>>,
},
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum CitationType {
TextContent,
Plan,
}
impl TryFrom<message::Message> for Vec<Message> {
type Error = message::MessageError;
fn try_from(message: message::Message) -> Result<Self, Self::Error> {
Ok(match message {
message::Message::User { content } => content
.into_iter()
.map(|content| match content {
message::UserContent::Text(message::Text { text }) => Ok(Message::User {
content: OneOrMany::one(UserContent::Text { text }),
}),
message::UserContent::ToolResult(message::ToolResult { id, content }) => {
Ok(Message::Tool {
tool_call_id: id,
content: content.try_map(|content| match content {
message::ToolResultContent::Text(text) => {
Ok(ToolResultContent::Text { text: text.text })
}
_ => Err(message::MessageError::ConversionError(
"Only text tool result content is supported by Cohere"
.to_owned(),
)),
})?,
})
}
_ => Err(message::MessageError::ConversionError(
"Only text content is supported by Cohere".to_owned(),
)),
})
.collect::<Result<Vec<_>, _>>()?,
message::Message::Assistant { content } => {
let mut text_content = vec![];
let mut tool_calls = vec![];
content.into_iter().for_each(|content| match content {
message::AssistantContent::Text(message::Text { text }) => {
text_content.push(AssistantContent::Text { text });
}
message::AssistantContent::ToolCall(message::ToolCall {
id,
function: message::ToolFunction { name, arguments },
}) => {
tool_calls.push(ToolCall {
id: Some(id),
r#type: Some(ToolType::Function),
function: Some(ToolCallFunction {
name,
arguments: serde_json::to_value(arguments).unwrap_or_default(),
}),
});
}
});
vec![Message::Assistant {
content: text_content,
citations: vec![],
tool_calls,
tool_plan: None,
}]
}
})
}
}
impl TryFrom<Message> for message::Message {
type Error = message::MessageError;
fn try_from(message: Message) -> Result<Self, Self::Error> {
match message {
Message::User { content } => Ok(message::Message::User {
content: content.map(|content| match content {
UserContent::Text { text } => {
message::UserContent::Text(message::Text { text })
}
UserContent::ImageUrl { image_url } => message::UserContent::image(
image_url.url,
Some(message::ContentFormat::String),
None,
None,
),
}),
}),
Message::Assistant {
content,
tool_calls,
..
} => {
let mut content = content
.into_iter()
.map(|content| match content {
AssistantContent::Text { text } => message::AssistantContent::text(text),
})
.collect::<Vec<_>>();
content.extend(tool_calls.into_iter().filter_map(|tool_call| {
let ToolCallFunction { name, arguments } = tool_call.function?;
Some(message::AssistantContent::tool_call(
tool_call.id.unwrap_or_else(|| name.clone()),
name,
arguments,
))
}));
let content = OneOrMany::many(content).map_err(|_| {
message::MessageError::ConversionError(
"Expected either text content or tool calls".to_string(),
)
})?;
Ok(message::Message::Assistant { content })
}
Message::Tool {
content,
tool_call_id,
} => {
let content = content.try_map(|content| {
Ok(match content {
ToolResultContent::Text { text } => message::ToolResultContent::text(text),
ToolResultContent::Document { document } => {
message::ToolResultContent::text(
serde_json::to_string(&document.data).map_err(|e| {
message::MessageError::ConversionError(
format!("Failed to convert tool result document content into text: {}", e),
)
})?,
)
}
})
})?;
Ok(message::Message::User {
content: OneOrMany::one(message::UserContent::tool_result(
tool_call_id,
content,
)),
})
}
Message::System { content } => Ok(message::Message::user(content)),
}
}
}
#[derive(Clone)]
pub struct CompletionModel {
client: Client,
pub model: String,
}
impl CompletionModel {
pub fn new(client: Client, model: &str) -> Self {
Self {
client,
model: model.to_string(),
}
}
}
impl completion::CompletionModel for CompletionModel {
type Response = CompletionResponse;
#[cfg_attr(feature = "worker", worker::send)]
async fn completion(
&self,
completion_request: completion::CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
let prompt = completion_request.prompt_with_context();
let mut messages: Vec<message::Message> =
if let Some(preamble) = completion_request.preamble {
vec![preamble.into()]
} else {
vec![]
};
messages.extend(completion_request.chat_history);
messages.push(prompt);
let messages: Vec<Message> = messages
.into_iter()
.map(|msg| msg.try_into())
.collect::<Result<Vec<Vec<_>>, _>>()?
.into_iter()
.flatten()
.collect();
let request = json!({
"model": self.model,
"messages": messages,
"documents": completion_request.documents,
"temperature": completion_request.temperature,
"tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
});
tracing::debug!(
"Cohere request: {}",
serde_json::to_string_pretty(&request)?
);
let response = self
.client
.post("/v2/chat")
.json(
&if let Some(ref params) = completion_request.additional_params {
json_utils::merge(request.clone(), params.clone())
} else {
request.clone()
},
)
.send()
.await?;
if response.status().is_success() {
let text_response = response.text().await?;
tracing::debug!("Cohere response text: {}", text_response);
let json_response: CompletionResponse = serde_json::from_str(&text_response)?;
let completion: completion::CompletionResponse<CompletionResponse> =
json_response.try_into()?;
Ok(completion)
} else {
Err(CompletionError::ProviderError(response.text().await?))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_path_to_error::deserialize;
#[test]
fn test_deserialize_completion_response() {
let json_data = r#"
{
"id": "abc123",
"message": {
"role": "assistant",
"tool_plan": "I will use the subtract tool to find the difference between 2 and 5.",
"tool_calls": [
{
"id": "subtract_sm6ps6fb6y9f",
"type": "function",
"function": {
"name": "subtract",
"arguments": "{\"x\":5,\"y\":2}"
}
}
]
},
"finish_reason": "TOOL_CALL",
"usage": {
"billed_units": {
"input_tokens": 78,
"output_tokens": 27
},
"tokens": {
"input_tokens": 1028,
"output_tokens": 63
}
}
}
"#;
let mut deserializer = serde_json::Deserializer::from_str(json_data);
let result: Result<CompletionResponse, _> = deserialize(&mut deserializer);
let response = result.unwrap();
let (_, citations, tool_calls) = response.message();
let CompletionResponse {
id,
finish_reason,
usage,
..
} = response;
assert_eq!(id, "abc123");
assert_eq!(finish_reason, FinishReason::ToolCall);
let Usage {
billed_units,
tokens,
} = usage.unwrap();
let BilledUnits {
input_tokens: billed_input_tokens,
output_tokens: billed_output_tokens,
..
} = billed_units.unwrap();
let Tokens {
input_tokens,
output_tokens,
} = tokens.unwrap();
assert_eq!(billed_input_tokens.unwrap(), 78.0);
assert_eq!(billed_output_tokens.unwrap(), 27.0);
assert_eq!(input_tokens.unwrap(), 1028.0);
assert_eq!(output_tokens.unwrap(), 63.0);
assert!(citations.is_empty());
assert_eq!(tool_calls.len(), 1);
let ToolCallFunction { name, arguments } = tool_calls[0].function.clone().unwrap();
assert_eq!(name, "subtract");
assert_eq!(arguments, serde_json::json!({"x": 5, "y": 2}));
}
#[test]
fn test_convert_completion_message_to_message_and_back() {
let completion_message = completion::Message::User {
content: OneOrMany::one(completion::message::UserContent::Text(
completion::message::Text {
text: "Hello, world!".to_string(),
},
)),
};
let messages: Vec<Message> = completion_message.clone().try_into().unwrap();
let _converted_back: Vec<completion::Message> = messages
.into_iter()
.map(|msg| msg.try_into().unwrap())
.collect::<Vec<_>>();
}
#[test]
fn test_convert_message_to_completion_message_and_back() {
let message = Message::User {
content: OneOrMany::one(UserContent::Text {
text: "Hello, world!".to_string(),
}),
};
let completion_message: completion::Message = message.clone().try_into().unwrap();
let _converted_back: Vec<Message> = completion_message.try_into().unwrap();
}
}

View File

@ -0,0 +1,142 @@
use super::{client::ApiResponse, Client};
use crate::embeddings::{self, EmbeddingError};
use serde::Deserialize;
use serde_json::json;
#[derive(Deserialize)]
pub struct EmbeddingResponse {
#[serde(default)]
pub response_type: Option<String>,
pub id: String,
pub embeddings: Vec<Vec<f64>>,
pub texts: Vec<String>,
#[serde(default)]
pub meta: Option<Meta>,
}
#[derive(Deserialize)]
pub struct Meta {
pub api_version: ApiVersion,
pub billed_units: BilledUnits,
#[serde(default)]
pub warnings: Vec<String>,
}
#[derive(Deserialize)]
pub struct ApiVersion {
pub version: String,
#[serde(default)]
pub is_deprecated: Option<bool>,
#[serde(default)]
pub is_experimental: Option<bool>,
}
#[derive(Deserialize, Debug)]
pub struct BilledUnits {
#[serde(default)]
pub input_tokens: u32,
#[serde(default)]
pub output_tokens: u32,
#[serde(default)]
pub search_units: u32,
#[serde(default)]
pub classifications: u32,
}
impl std::fmt::Display for BilledUnits {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Input tokens: {}\nOutput tokens: {}\nSearch units: {}\nClassifications: {}",
self.input_tokens, self.output_tokens, self.search_units, self.classifications
)
}
}
#[derive(Clone)]
pub struct EmbeddingModel {
client: Client,
pub model: String,
pub input_type: String,
ndims: usize,
}
impl embeddings::EmbeddingModel for EmbeddingModel {
const MAX_DOCUMENTS: usize = 96;
fn ndims(&self) -> usize {
self.ndims
}
#[cfg_attr(feature = "worker", worker::send)]
async fn embed_texts(
&self,
documents: impl IntoIterator<Item = String>,
) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
let documents = documents.into_iter().collect::<Vec<_>>();
let response = self
.client
.post("/v1/embed")
.json(&json!({
"model": self.model,
"texts": documents,
"input_type": self.input_type,
}))
.send()
.await?;
if response.status().is_success() {
match response.json::<ApiResponse<EmbeddingResponse>>().await? {
ApiResponse::Ok(response) => {
match response.meta {
Some(meta) => tracing::info!(target: "rig",
"Cohere embeddings billed units: {}",
meta.billed_units,
),
None => tracing::info!(target: "rig",
"Cohere embeddings billed units: n/a",
),
};
if response.embeddings.len() != documents.len() {
return Err(EmbeddingError::DocumentError(
format!(
"Expected {} embeddings, got {}",
documents.len(),
response.embeddings.len()
)
.into(),
));
}
Ok(response
.embeddings
.into_iter()
.zip(documents.into_iter())
.map(|(embedding, document)| embeddings::Embedding {
document,
vec: embedding,
})
.collect())
}
ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)),
}
} else {
Err(EmbeddingError::ProviderError(response.text().await?))
}
}
}
impl EmbeddingModel {
pub fn new(client: Client, model: &str, input_type: &str, ndims: usize) -> Self {
Self {
client,
model: model.to_string(),
input_type: input_type.to_string(),
ndims,
}
}
}

View File

@ -0,0 +1,55 @@
//! Cohere API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::cohere;
//!
//! let client = cohere::Client::new("YOUR_API_KEY");
//!
//! let command_r = client.completion_model(cohere::COMMAND_R);
//! ```
pub mod client;
pub mod completion;
pub mod embeddings;
pub use client::Client;
pub use client::{ApiErrorResponse, ApiResponse};
pub use completion::CompletionModel;
pub use embeddings::EmbeddingModel;
// ================================================================
// Cohere Completion Models
// ================================================================
/// `command-r-plus` completion model
pub const COMMAND_R_PLUS: &str = "comman-r-plus";
/// `command-r` completion model
pub const COMMAND_R: &str = "command-r";
/// `command` completion model
pub const COMMAND: &str = "command";
/// `command-nightly` completion model
pub const COMMAND_NIGHTLY: &str = "command-nightly";
/// `command-light` completion model
pub const COMMAND_LIGHT: &str = "command-light";
/// `command-light-nightly` completion model
pub const COMMAND_LIGHT_NIGHTLY: &str = "command-light-nightly";
// ================================================================
// Cohere Embedding Models
// ================================================================
/// `embed-english-v3.0` embedding model
pub const EMBED_ENGLISH_V3: &str = "embed-english-v3.0";
/// `embed-english-light-v3.0` embedding model
pub const EMBED_ENGLISH_LIGHT_V3: &str = "embed-english-light-v3.0";
/// `embed-multilingual-v3.0` embedding model
pub const EMBED_MULTILINGUAL_V3: &str = "embed-multilingual-v3.0";
/// `embed-multilingual-light-v3.0` embedding model
pub const EMBED_MULTILINGUAL_LIGHT_V3: &str = "embed-multilingual-light-v3.0";
/// `embed-english-v2.0` embedding model
pub const EMBED_ENGLISH_V2: &str = "embed-english-v2.0";
/// `embed-english-light-v2.0` embedding model
pub const EMBED_ENGLISH_LIGHT_V2: &str = "embed-english-light-v2.0";
/// `embed-multilingual-v2.0` embedding model
pub const EMBED_MULTILINGUAL_V2: &str = "embed-multilingual-v2.0";