mirror of https://github.com/0xplaygrounds/rig
feat(provider): cohere-v2 (#350)
* fix(cohere): partial v2 * feat(cohere): partial * feat: cohere v2 impl * fix(cohere): working cohere impl
This commit is contained in:
parent
19c095122f
commit
599f8b46c3
|
@ -0,0 +1,125 @@
|
|||
use rig::{
|
||||
completion::{Prompt, ToolDefinition},
|
||||
providers,
|
||||
tool::Tool,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_max_level(tracing::Level::DEBUG)
|
||||
.with_target(false)
|
||||
.init();
|
||||
|
||||
let client = providers::cohere::Client::from_env();
|
||||
let agent = client
|
||||
.agent("command-r")
|
||||
.preamble("You are a helpful assistant.")
|
||||
.build();
|
||||
|
||||
let answer = agent.prompt("Tell me a joke").await?;
|
||||
println!("Answer: {}", answer);
|
||||
|
||||
// Create agent with a single context prompt and two tools
|
||||
let calculator_agent = client
|
||||
.agent(providers::cohere::COMMAND_R)
|
||||
.preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
|
||||
.max_tokens(1024)
|
||||
.tool(Adder)
|
||||
.tool(Subtract)
|
||||
.build();
|
||||
|
||||
// Prompt the agent and print the response
|
||||
println!("Calculate 2 - 5");
|
||||
println!(
|
||||
"Cohere Calculator Agent: {}",
|
||||
calculator_agent.prompt("Calculate 2 - 5").await?
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OperationArgs {
|
||||
x: i32,
|
||||
y: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("Math error")]
|
||||
struct MathError;
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
struct Adder;
|
||||
impl Tool for Adder {
|
||||
const NAME: &'static str = "add";
|
||||
|
||||
type Error = MathError;
|
||||
type Args = OperationArgs;
|
||||
type Output = i32;
|
||||
|
||||
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||
ToolDefinition {
|
||||
name: "add".to_string(),
|
||||
description: "Add x and y together".to_string(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {
|
||||
"type": "number",
|
||||
"description": "The first number to add"
|
||||
},
|
||||
"y": {
|
||||
"type": "number",
|
||||
"description": "The second number to add"
|
||||
}
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||
println!("[tool-call] Adding {} and {}", args.x, args.y);
|
||||
let result = args.x + args.y;
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
struct Subtract;
|
||||
impl Tool for Subtract {
|
||||
const NAME: &'static str = "subtract";
|
||||
|
||||
type Error = MathError;
|
||||
type Args = OperationArgs;
|
||||
type Output = i32;
|
||||
|
||||
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||
serde_json::from_value(json!({
|
||||
"name": "subtract",
|
||||
"description": "Subtract y from x (i.e.: x - y)",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {
|
||||
"type": "number",
|
||||
"description": "The number to subtract from"
|
||||
},
|
||||
"y": {
|
||||
"type": "number",
|
||||
"description": "The number to subtract"
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
.expect("Tool Definition")
|
||||
}
|
||||
|
||||
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||
println!("[tool-call] Subtracting {} from {}", args.y, args.x);
|
||||
let result = args.x - args.y;
|
||||
Ok(result)
|
||||
}
|
||||
}
|
|
@ -1,50 +0,0 @@
|
|||
use std::env;
|
||||
|
||||
use rig::{
|
||||
completion::{Completion, Prompt},
|
||||
providers::cohere::Client as CohereClient,
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
// Create Cohere client
|
||||
let cohere_api_key = env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
|
||||
let cohere_client = CohereClient::new(&cohere_api_key);
|
||||
|
||||
let klimadao_agent = cohere_client
|
||||
.agent("command-r")
|
||||
.temperature(0.0)
|
||||
.additional_params(json!({
|
||||
"connectors": [{"id":"web-search", "options":{"site": "https://docs.klimadao.finance"}}]
|
||||
}))
|
||||
.build();
|
||||
|
||||
// Prompt the model and print the response
|
||||
// We use `prompt` to get a simple response from the model as a String
|
||||
let response = klimadao_agent.prompt("Tell me about BCT tokens?").await?;
|
||||
|
||||
println!("\n\nCoral: {:?}", response);
|
||||
|
||||
// Prompt the model and get the citations
|
||||
// We use `completion` to allow use to customize the request further and
|
||||
// get a more detailed response from the model.
|
||||
// Here the response is of type CompletionResponse<cohere::CompletionResponse>
|
||||
// which contains `choice` (Message or ToolCall) as well as `raw_response`,
|
||||
// the underlying providers' raw response.
|
||||
let response = klimadao_agent
|
||||
.completion("Tell me about BCT tokens?", vec![])
|
||||
.await?
|
||||
.additional_params(json!({
|
||||
"connectors": [{"id":"web-search", "options":{"site": "https://docs.klimadao.finance"}}]
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
println!(
|
||||
"\n\nCoral: {:?}\n\nCitations:\n{:?}",
|
||||
response.choice, response.raw_response.citations
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -15,6 +15,11 @@ struct Debater {
|
|||
|
||||
impl Debater {
|
||||
fn new(position_a: &str, position_b: &str) -> Self {
|
||||
tracing_subscriber::fmt()
|
||||
.with_max_level(tracing::Level::INFO)
|
||||
.with_target(false)
|
||||
.init();
|
||||
|
||||
let openai_client =
|
||||
openai::Client::new(&env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"));
|
||||
let cohere_client =
|
||||
|
|
|
@ -1,623 +0,0 @@
|
|||
//! Cohere API client and Rig integration
|
||||
//!
|
||||
//! # Example
|
||||
//! ```
|
||||
//! use rig::providers::cohere;
|
||||
//!
|
||||
//! let client = cohere::Client::new("YOUR_API_KEY");
|
||||
//!
|
||||
//! let command_r = client.completion_model(cohere::COMMAND_R);
|
||||
//! ```
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::{
|
||||
agent::AgentBuilder,
|
||||
completion::{self, CompletionError},
|
||||
embeddings::{self, EmbeddingError, EmbeddingsBuilder},
|
||||
extractor::ExtractorBuilder,
|
||||
json_utils, message, Embed, OneOrMany,
|
||||
};
|
||||
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
|
||||
// ================================================================
|
||||
// Main Cohere Client
|
||||
// ================================================================
|
||||
const COHERE_API_BASE_URL: &str = "https://api.cohere.ai";
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Client {
|
||||
base_url: String,
|
||||
http_client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
pub fn new(api_key: &str) -> Self {
|
||||
Self::from_url(api_key, COHERE_API_BASE_URL)
|
||||
}
|
||||
|
||||
pub fn from_url(api_key: &str, base_url: &str) -> Self {
|
||||
Self {
|
||||
base_url: base_url.to_string(),
|
||||
http_client: reqwest::Client::builder()
|
||||
.default_headers({
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
"Authorization",
|
||||
format!("Bearer {}", api_key)
|
||||
.parse()
|
||||
.expect("Bearer token should parse"),
|
||||
);
|
||||
headers
|
||||
})
|
||||
.build()
|
||||
.expect("Cohere reqwest client should build"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new Cohere client from the `COHERE_API_KEY` environment variable.
|
||||
/// Panics if the environment variable is not set.
|
||||
pub fn from_env() -> Self {
|
||||
let api_key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
|
||||
Self::new(&api_key)
|
||||
}
|
||||
|
||||
pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
|
||||
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
|
||||
self.http_client.post(url)
|
||||
}
|
||||
|
||||
/// Note: default embedding dimension of 0 will be used if model is not known.
|
||||
/// If this is the case, it's better to use function `embedding_model_with_ndims`
|
||||
pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel {
|
||||
let ndims = match model {
|
||||
EMBED_ENGLISH_V3 | EMBED_MULTILINGUAL_V3 | EMBED_ENGLISH_LIGHT_V2 => 1024,
|
||||
EMBED_ENGLISH_LIGHT_V3 | EMBED_MULTILINGUAL_LIGHT_V3 => 384,
|
||||
EMBED_ENGLISH_V2 => 4096,
|
||||
EMBED_MULTILINGUAL_V2 => 768,
|
||||
_ => 0,
|
||||
};
|
||||
EmbeddingModel::new(self.clone(), model, input_type, ndims)
|
||||
}
|
||||
|
||||
/// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
|
||||
pub fn embedding_model_with_ndims(
|
||||
&self,
|
||||
model: &str,
|
||||
input_type: &str,
|
||||
ndims: usize,
|
||||
) -> EmbeddingModel {
|
||||
EmbeddingModel::new(self.clone(), model, input_type, ndims)
|
||||
}
|
||||
|
||||
pub fn embeddings<D: Embed>(
|
||||
&self,
|
||||
model: &str,
|
||||
input_type: &str,
|
||||
) -> EmbeddingsBuilder<EmbeddingModel, D> {
|
||||
EmbeddingsBuilder::new(self.embedding_model(model, input_type))
|
||||
}
|
||||
|
||||
pub fn completion_model(&self, model: &str) -> CompletionModel {
|
||||
CompletionModel::new(self.clone(), model)
|
||||
}
|
||||
|
||||
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
|
||||
AgentBuilder::new(self.completion_model(model))
|
||||
}
|
||||
|
||||
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
|
||||
&self,
|
||||
model: &str,
|
||||
) -> ExtractorBuilder<T, CompletionModel> {
|
||||
ExtractorBuilder::new(self.completion_model(model))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ApiErrorResponse {
|
||||
message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum ApiResponse<T> {
|
||||
Ok(T),
|
||||
Err(ApiErrorResponse),
|
||||
}
|
||||
|
||||
// ================================================================
|
||||
// Cohere Embedding API
|
||||
// ================================================================
|
||||
/// Model identifier `embed-english-v3.0`.
pub const EMBED_ENGLISH_V3: &str = "embed-english-v3.0";
/// Model identifier `embed-english-light-v3.0`.
pub const EMBED_ENGLISH_LIGHT_V3: &str = "embed-english-light-v3.0";
/// Model identifier `embed-multilingual-v3.0`.
pub const EMBED_MULTILINGUAL_V3: &str = "embed-multilingual-v3.0";
/// Model identifier `embed-multilingual-light-v3.0`.
pub const EMBED_MULTILINGUAL_LIGHT_V3: &str = "embed-multilingual-light-v3.0";
/// Model identifier `embed-english-v2.0`.
pub const EMBED_ENGLISH_V2: &str = "embed-english-v2.0";
/// Model identifier `embed-english-light-v2.0`.
pub const EMBED_ENGLISH_LIGHT_V2: &str = "embed-english-light-v2.0";
/// Model identifier `embed-multilingual-v2.0`.
pub const EMBED_MULTILINGUAL_V2: &str = "embed-multilingual-v2.0";
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct EmbeddingResponse {
|
||||
#[serde(default)]
|
||||
pub response_type: Option<String>,
|
||||
pub id: String,
|
||||
pub embeddings: Vec<Vec<f64>>,
|
||||
pub texts: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub meta: Option<Meta>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct Meta {
|
||||
pub api_version: ApiVersion,
|
||||
pub billed_units: BilledUnits,
|
||||
#[serde(default)]
|
||||
pub warnings: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct ApiVersion {
|
||||
pub version: String,
|
||||
#[serde(default)]
|
||||
pub is_deprecated: Option<bool>,
|
||||
#[serde(default)]
|
||||
pub is_experimental: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct BilledUnits {
|
||||
#[serde(default)]
|
||||
pub input_tokens: u32,
|
||||
#[serde(default)]
|
||||
pub output_tokens: u32,
|
||||
#[serde(default)]
|
||||
pub search_units: u32,
|
||||
#[serde(default)]
|
||||
pub classifications: u32,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for BilledUnits {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Input tokens: {}\nOutput tokens: {}\nSearch units: {}\nClassifications: {}",
|
||||
self.input_tokens, self.output_tokens, self.search_units, self.classifications
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EmbeddingModel {
|
||||
client: Client,
|
||||
pub model: String,
|
||||
pub input_type: String,
|
||||
ndims: usize,
|
||||
}
|
||||
|
||||
impl embeddings::EmbeddingModel for EmbeddingModel {
|
||||
const MAX_DOCUMENTS: usize = 96;
|
||||
|
||||
fn ndims(&self) -> usize {
|
||||
self.ndims
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "worker", worker::send)]
|
||||
async fn embed_texts(
|
||||
&self,
|
||||
documents: impl IntoIterator<Item = String>,
|
||||
) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
|
||||
let documents = documents.into_iter().collect::<Vec<_>>();
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post("/v1/embed")
|
||||
.json(&json!({
|
||||
"model": self.model,
|
||||
"texts": documents,
|
||||
"input_type": self.input_type,
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if response.status().is_success() {
|
||||
match response.json::<ApiResponse<EmbeddingResponse>>().await? {
|
||||
ApiResponse::Ok(response) => {
|
||||
match response.meta {
|
||||
Some(meta) => tracing::info!(target: "rig",
|
||||
"Cohere embeddings billed units: {}",
|
||||
meta.billed_units,
|
||||
),
|
||||
None => tracing::info!(target: "rig",
|
||||
"Cohere embeddings billed units: n/a",
|
||||
),
|
||||
};
|
||||
|
||||
if response.embeddings.len() != documents.len() {
|
||||
return Err(EmbeddingError::DocumentError(
|
||||
format!(
|
||||
"Expected {} embeddings, got {}",
|
||||
documents.len(),
|
||||
response.embeddings.len()
|
||||
)
|
||||
.into(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(response
|
||||
.embeddings
|
||||
.into_iter()
|
||||
.zip(documents.into_iter())
|
||||
.map(|(embedding, document)| embeddings::Embedding {
|
||||
document,
|
||||
vec: embedding,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)),
|
||||
}
|
||||
} else {
|
||||
Err(EmbeddingError::ProviderError(response.text().await?))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EmbeddingModel {
|
||||
pub fn new(client: Client, model: &str, input_type: &str, ndims: usize) -> Self {
|
||||
Self {
|
||||
client,
|
||||
model: model.to_string(),
|
||||
input_type: input_type.to_string(),
|
||||
ndims,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ================================================================
|
||||
// Cohere Completion API
|
||||
// ================================================================
|
||||
/// `command-r-plus` completion model
// The value previously read "comman-r-plus", which did not match the
// documented model name above.
pub const COMMAND_R_PLUS: &str = "command-r-plus";
/// `command-r` completion model
pub const COMMAND_R: &str = "command-r";
/// `command` completion model
pub const COMMAND: &str = "command";
/// `command-nightly` completion model
pub const COMMAND_NIGHTLY: &str = "command-nightly";
/// `command-light` completion model
pub const COMMAND_LIGHT: &str = "command-light";
/// `command-light-nightly` completion model
pub const COMMAND_LIGHT_NIGHTLY: &str = "command-light-nightly";
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CompletionResponse {
|
||||
pub text: String,
|
||||
pub generation_id: String,
|
||||
#[serde(default)]
|
||||
pub citations: Vec<Citation>,
|
||||
#[serde(default)]
|
||||
pub documents: Vec<Document>,
|
||||
#[serde(default)]
|
||||
pub is_search_required: Option<bool>,
|
||||
#[serde(default)]
|
||||
pub search_queries: Vec<SearchQuery>,
|
||||
#[serde(default)]
|
||||
pub search_results: Vec<SearchResult>,
|
||||
pub finish_reason: String,
|
||||
#[serde(default)]
|
||||
pub tool_calls: Vec<ToolCall>,
|
||||
#[serde(default)]
|
||||
pub chat_history: Vec<ChatHistory>,
|
||||
}
|
||||
|
||||
impl From<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
|
||||
fn from(response: CompletionResponse) -> Self {
|
||||
let CompletionResponse {
|
||||
text, tool_calls, ..
|
||||
} = &response;
|
||||
|
||||
let model_response = if !tool_calls.is_empty() {
|
||||
tool_calls
|
||||
.iter()
|
||||
.map(|tool_call| {
|
||||
completion::AssistantContent::tool_call(
|
||||
tool_call.name.clone(),
|
||||
tool_call.name.clone(),
|
||||
tool_call.parameters.clone(),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
vec![completion::AssistantContent::text(text.clone())]
|
||||
};
|
||||
|
||||
completion::CompletionResponse {
|
||||
choice: OneOrMany::many(model_response).expect("There is atleast one content"),
|
||||
raw_response: response,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Citation {
|
||||
pub start: u32,
|
||||
pub end: u32,
|
||||
pub text: String,
|
||||
pub document_ids: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Document {
|
||||
pub id: String,
|
||||
#[serde(flatten)]
|
||||
pub additional_prop: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SearchQuery {
|
||||
pub text: String,
|
||||
pub generation_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SearchResult {
|
||||
pub search_query: SearchQuery,
|
||||
pub connector: Connector,
|
||||
pub document_ids: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub error_message: Option<String>,
|
||||
#[serde(default)]
|
||||
pub continue_on_failure: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Connector {
|
||||
pub id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct ToolCall {
|
||||
pub name: String,
|
||||
pub parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ChatHistory {
|
||||
pub role: String,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct Parameter {
|
||||
pub description: String,
|
||||
pub r#type: String,
|
||||
pub required: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct ToolDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameter_definitions: HashMap<String, Parameter>,
|
||||
}
|
||||
|
||||
impl From<completion::ToolDefinition> for ToolDefinition {
|
||||
fn from(tool: completion::ToolDefinition) -> Self {
|
||||
fn convert_type(r#type: &serde_json::Value) -> String {
|
||||
fn convert_type_str(r#type: &str) -> String {
|
||||
match r#type {
|
||||
"string" => "string".to_owned(),
|
||||
"number" => "number".to_owned(),
|
||||
"integer" => "integer".to_owned(),
|
||||
"boolean" => "boolean".to_owned(),
|
||||
"array" => "array".to_owned(),
|
||||
"object" => "object".to_owned(),
|
||||
_ => "string".to_owned(),
|
||||
}
|
||||
}
|
||||
match r#type {
|
||||
serde_json::Value::String(r#type) => convert_type_str(r#type.as_str()),
|
||||
serde_json::Value::Array(types) => convert_type_str(
|
||||
types
|
||||
.iter()
|
||||
.find(|t| t.as_str() != Some("null"))
|
||||
.and_then(|t| t.as_str())
|
||||
.unwrap_or("string"),
|
||||
),
|
||||
_ => "string".to_owned(),
|
||||
}
|
||||
}
|
||||
|
||||
let maybe_required = tool
|
||||
.parameters
|
||||
.get("required")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|required| {
|
||||
required
|
||||
.iter()
|
||||
.filter_map(|v| v.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
Self {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameter_definitions: tool
|
||||
.parameters
|
||||
.get("properties")
|
||||
.expect("Tool properties should exist")
|
||||
.as_object()
|
||||
.expect("Tool properties should be an object")
|
||||
.iter()
|
||||
.map(|(argname, argdef)| {
|
||||
(
|
||||
argname.clone(),
|
||||
Parameter {
|
||||
description: argdef
|
||||
.get("description")
|
||||
.expect("Argument description should exist")
|
||||
.as_str()
|
||||
.expect("Argument description should be a string")
|
||||
.to_string(),
|
||||
r#type: convert_type(
|
||||
argdef.get("type").expect("Argument type should exist"),
|
||||
),
|
||||
required: maybe_required.contains(&argname.as_str()),
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect::<HashMap<_, _>>(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
#[serde(tag = "role", rename_all = "UPPERCASE")]
|
||||
pub enum Message {
|
||||
User {
|
||||
message: String,
|
||||
tool_calls: Vec<ToolCall>,
|
||||
},
|
||||
|
||||
Chatbot {
|
||||
message: String,
|
||||
tool_calls: Vec<ToolCall>,
|
||||
},
|
||||
|
||||
Tool {
|
||||
tool_results: Vec<ToolResult>,
|
||||
},
|
||||
|
||||
/// According to the documentation, this message type should not be used
|
||||
System {
|
||||
content: String,
|
||||
tool_calls: Vec<ToolCall>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
pub struct ToolResult {
|
||||
pub call: ToolCall,
|
||||
pub outputs: Vec<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl TryFrom<message::Message> for Vec<Message> {
|
||||
type Error = message::MessageError;
|
||||
|
||||
fn try_from(message: message::Message) -> Result<Self, Self::Error> {
|
||||
match message {
|
||||
message::Message::User { content } => content
|
||||
.into_iter()
|
||||
.map(|content| {
|
||||
Ok(Message::User {
|
||||
message: match content {
|
||||
message::UserContent::Text(message::Text { text }) => text,
|
||||
_ => {
|
||||
return Err(message::MessageError::ConversionError(
|
||||
"Only text content is supported by Cohere".to_owned(),
|
||||
))
|
||||
}
|
||||
},
|
||||
tool_calls: vec![],
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>(),
|
||||
_ => Err(message::MessageError::ConversionError(
|
||||
"Only user messages are supported by Cohere".to_owned(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CompletionModel {
|
||||
client: Client,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
impl CompletionModel {
|
||||
pub fn new(client: Client, model: &str) -> Self {
|
||||
Self {
|
||||
client,
|
||||
model: model.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl completion::CompletionModel for CompletionModel {
|
||||
type Response = CompletionResponse;
|
||||
|
||||
#[cfg_attr(feature = "worker", worker::send)]
|
||||
async fn completion(
|
||||
&self,
|
||||
completion_request: completion::CompletionRequest,
|
||||
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
|
||||
let chat_history = completion_request
|
||||
.chat_history
|
||||
.into_iter()
|
||||
.map(Vec::<Message>::try_from)
|
||||
.collect::<Result<Vec<Vec<_>>, _>>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let message = match completion_request.prompt {
|
||||
message::Message::User { content } => Ok(content
|
||||
.into_iter()
|
||||
.map(|content| match content {
|
||||
message::UserContent::Text(message::Text { text }) => Ok(text),
|
||||
_ => Err(CompletionError::RequestError(
|
||||
"Only text content is supported by Cohere".into(),
|
||||
)),
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.join("\n")),
|
||||
|
||||
_ => Err(CompletionError::RequestError(
|
||||
"Only user messages are supported by Cohere".into(),
|
||||
)),
|
||||
}?;
|
||||
|
||||
let request = json!({
|
||||
"model": self.model,
|
||||
"preamble": completion_request.preamble,
|
||||
"message": message,
|
||||
"documents": completion_request.documents,
|
||||
"chat_history": chat_history,
|
||||
"temperature": completion_request.temperature,
|
||||
"tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
|
||||
});
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post("/v1/chat")
|
||||
.json(
|
||||
&if let Some(ref params) = completion_request.additional_params {
|
||||
json_utils::merge(request.clone(), params.clone())
|
||||
} else {
|
||||
request.clone()
|
||||
},
|
||||
)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if response.status().is_success() {
|
||||
match response.json::<ApiResponse<CompletionResponse>>().await? {
|
||||
ApiResponse::Ok(completion) => Ok(completion.into()),
|
||||
ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
|
||||
}
|
||||
} else {
|
||||
Err(CompletionError::ProviderError(response.text().await?))
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,116 @@
|
|||
use crate::{
|
||||
agent::AgentBuilder, embeddings::EmbeddingsBuilder, extractor::ExtractorBuilder, Embed,
|
||||
};
|
||||
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::{CompletionModel, EmbeddingModel};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ApiErrorResponse {
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ApiResponse<T> {
|
||||
Ok(T),
|
||||
Err(ApiErrorResponse),
|
||||
}
|
||||
|
||||
// ================================================================
|
||||
// Main Cohere Client
|
||||
// ================================================================
|
||||
const COHERE_API_BASE_URL: &str = "https://api.cohere.ai";
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Client {
|
||||
base_url: String,
|
||||
http_client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
pub fn new(api_key: &str) -> Self {
|
||||
Self::from_url(api_key, COHERE_API_BASE_URL)
|
||||
}
|
||||
|
||||
pub fn from_url(api_key: &str, base_url: &str) -> Self {
|
||||
Self {
|
||||
base_url: base_url.to_string(),
|
||||
http_client: reqwest::Client::builder()
|
||||
.default_headers({
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
"Authorization",
|
||||
format!("Bearer {}", api_key)
|
||||
.parse()
|
||||
.expect("Bearer token should parse"),
|
||||
);
|
||||
headers
|
||||
})
|
||||
.build()
|
||||
.expect("Cohere reqwest client should build"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new Cohere client from the `COHERE_API_KEY` environment variable.
|
||||
/// Panics if the environment variable is not set.
|
||||
pub fn from_env() -> Self {
|
||||
let api_key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
|
||||
Self::new(&api_key)
|
||||
}
|
||||
|
||||
pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
|
||||
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
|
||||
self.http_client.post(url)
|
||||
}
|
||||
|
||||
/// Note: default embedding dimension of 0 will be used if model is not known.
|
||||
/// If this is the case, it's better to use function `embedding_model_with_ndims`
|
||||
pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel {
|
||||
let ndims = match model {
|
||||
super::EMBED_ENGLISH_V3
|
||||
| super::EMBED_MULTILINGUAL_V3
|
||||
| super::EMBED_ENGLISH_LIGHT_V2 => 1024,
|
||||
super::EMBED_ENGLISH_LIGHT_V3 | super::EMBED_MULTILINGUAL_LIGHT_V3 => 384,
|
||||
super::EMBED_ENGLISH_V2 => 4096,
|
||||
super::EMBED_MULTILINGUAL_V2 => 768,
|
||||
_ => 0,
|
||||
};
|
||||
EmbeddingModel::new(self.clone(), model, input_type, ndims)
|
||||
}
|
||||
|
||||
/// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
|
||||
pub fn embedding_model_with_ndims(
|
||||
&self,
|
||||
model: &str,
|
||||
input_type: &str,
|
||||
ndims: usize,
|
||||
) -> EmbeddingModel {
|
||||
EmbeddingModel::new(self.clone(), model, input_type, ndims)
|
||||
}
|
||||
|
||||
pub fn embeddings<D: Embed>(
|
||||
&self,
|
||||
model: &str,
|
||||
input_type: &str,
|
||||
) -> EmbeddingsBuilder<EmbeddingModel, D> {
|
||||
EmbeddingsBuilder::new(self.embedding_model(model, input_type))
|
||||
}
|
||||
|
||||
pub fn completion_model(&self, model: &str) -> CompletionModel {
|
||||
CompletionModel::new(self.clone(), model)
|
||||
}
|
||||
|
||||
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
|
||||
AgentBuilder::new(self.completion_model(model))
|
||||
}
|
||||
|
||||
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
|
||||
&self,
|
||||
model: &str,
|
||||
) -> ExtractorBuilder<T, CompletionModel> {
|
||||
ExtractorBuilder::new(self.completion_model(model))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,611 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use crate::{
|
||||
completion::{self, CompletionError},
|
||||
json_utils, message, OneOrMany,
|
||||
};
|
||||
|
||||
use super::client::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CompletionResponse {
|
||||
pub id: String,
|
||||
pub finish_reason: FinishReason,
|
||||
message: Message,
|
||||
#[serde(default)]
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
impl CompletionResponse {
|
||||
/// Return that parts of the response for assistant messages w/o dealing with the other variants
|
||||
pub fn message(&self) -> (Vec<AssistantContent>, Vec<Citation>, Vec<ToolCall>) {
|
||||
let Message::Assistant {
|
||||
content,
|
||||
citations,
|
||||
tool_calls,
|
||||
..
|
||||
} = self.message.clone()
|
||||
else {
|
||||
unreachable!("Completion responses will only return an assistant message")
|
||||
};
|
||||
|
||||
(content, citations, tool_calls)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, PartialEq, Eq, Clone)]
|
||||
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
|
||||
pub enum FinishReason {
|
||||
MaxTokens,
|
||||
StopSequence,
|
||||
Complete,
|
||||
Error,
|
||||
ToolCall,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct Usage {
|
||||
#[serde(default)]
|
||||
pub billed_units: Option<BilledUnits>,
|
||||
#[serde(default)]
|
||||
pub tokens: Option<Tokens>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct BilledUnits {
|
||||
#[serde(default)]
|
||||
pub output_tokens: Option<f64>,
|
||||
#[serde(default)]
|
||||
pub classifications: Option<f64>,
|
||||
#[serde(default)]
|
||||
pub search_units: Option<f64>,
|
||||
#[serde(default)]
|
||||
pub input_tokens: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct Tokens {
|
||||
#[serde(default)]
|
||||
pub input_tokens: Option<f64>,
|
||||
#[serde(default)]
|
||||
pub output_tokens: Option<f64>,
|
||||
}
|
||||
|
||||
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
|
||||
type Error = CompletionError;
|
||||
|
||||
fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
|
||||
let (content, _, tool_calls) = response.message();
|
||||
|
||||
let model_response = if !tool_calls.is_empty() {
|
||||
OneOrMany::many(
|
||||
tool_calls
|
||||
.into_iter()
|
||||
.filter_map(|tool_call| {
|
||||
let ToolCallFunction { name, arguments } = tool_call.function?;
|
||||
let id = tool_call.id.unwrap_or_else(|| name.clone());
|
||||
|
||||
Some(completion::AssistantContent::tool_call(id, name, arguments))
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
.expect("We have atleast 1 tool call in this if block")
|
||||
} else {
|
||||
OneOrMany::many(content.into_iter().map(|content| match content {
|
||||
AssistantContent::Text { text } => completion::AssistantContent::text(text),
|
||||
}))
|
||||
.map_err(|_| {
|
||||
CompletionError::ResponseError(
|
||||
"Response contained no message or tool call (empty)".to_owned(),
|
||||
)
|
||||
})?
|
||||
};
|
||||
|
||||
Ok(completion::CompletionResponse {
|
||||
choice: OneOrMany::many(model_response).expect("There is atleast one content"),
|
||||
raw_response: response,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A document in the shape used by this provider: an `id` plus a free-form `data` map of
/// JSON values.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Document {
    pub id: String,
    pub data: HashMap<String, serde_json::Value>,
}
|
||||
|
||||
impl From<completion::Document> for Document {
|
||||
fn from(document: completion::Document) -> Self {
|
||||
let mut data: HashMap<String, serde_json::Value> = HashMap::new();
|
||||
|
||||
// We use `.into()` here explicitly since the `document.additional_props` type will likely
|
||||
// evolve into `serde_json::Value` in the future.
|
||||
document
|
||||
.additional_props
|
||||
.into_iter()
|
||||
.for_each(|(key, value)| {
|
||||
data.insert(key, value.into());
|
||||
});
|
||||
|
||||
data.insert("text".to_string(), document.text.into());
|
||||
|
||||
Self {
|
||||
id: document.id,
|
||||
data,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A tool call carried by an assistant message.
///
/// Every field is optional and deserializes to `None` when missing.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct ToolCall {
    #[serde(default)]
    pub id: Option<String>,
    #[serde(default)]
    pub r#type: Option<ToolType>,
    #[serde(default)]
    pub function: Option<ToolCallFunction>,
}
|
||||
|
||||
/// The function part of a [`ToolCall`]: the function name and its arguments.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct ToolCallFunction {
    pub name: String,
    /// (De)serialized through `json_utils::stringified_json`; on the wire the arguments are a
    /// JSON document encoded as a string (e.g. `"{\"x\":5,\"y\":2}"`).
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
|
||||
|
||||
/// Kind of tool; serialized in lowercase (`"function"`).
///
/// `Function` is the only variant and also the default.
#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
|
||||
|
||||
/// A tool definition as placed in the request's `tools` array: a type tag plus the function
/// description.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Tool {
    pub r#type: ToolType,
    pub function: Function,
}
|
||||
|
||||
/// Description of a callable function inside a [`Tool`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Function {
    pub name: String,
    /// Optional description; `None` when missing from the input.
    #[serde(default)]
    pub description: Option<String>,
    /// Parameter specification as an arbitrary JSON value.
    pub parameters: serde_json::Value,
}
|
||||
|
||||
impl From<completion::ToolDefinition> for Tool {
|
||||
fn from(tool: completion::ToolDefinition) -> Self {
|
||||
Self {
|
||||
r#type: ToolType::default(),
|
||||
function: Function {
|
||||
name: tool.name,
|
||||
description: Some(tool.description),
|
||||
parameters: tool.parameters,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A chat message in the provider's format.
///
/// The variant is selected by the lowercase `role` field
/// (`"user"`, `"assistant"`, `"tool"`, `"system"`).
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// A user message with one or more content parts.
    User {
        content: OneOrMany<UserContent>,
    },

    /// An assistant message; every field falls back to its default when missing.
    Assistant {
        #[serde(default)]
        content: Vec<AssistantContent>,
        #[serde(default)]
        citations: Vec<Citation>,
        #[serde(default)]
        tool_calls: Vec<ToolCall>,
        #[serde(default)]
        tool_plan: Option<String>,
    },

    /// The result of a tool call, linked to the call through `tool_call_id`.
    Tool {
        content: OneOrMany<ToolResultContent>,
        tool_call_id: String,
    },

    /// A system message holding plain text.
    System {
        content: String,
    },
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
pub enum UserContent {
|
||||
Text { text: String },
|
||||
ImageUrl { image_url: ImageUrl },
|
||||
}
|
||||
|
||||
/// One content part of an assistant message, tagged by its lowercase `type` field.
/// Only text is modelled.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
}
|
||||
|
||||
/// Payload of [`UserContent::ImageUrl`]: the image location as a string.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct ImageUrl {
    pub url: String,
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
|
||||
pub enum ToolResultContent {
|
||||
Text { text: String },
|
||||
Document { document: Document },
|
||||
}
|
||||
|
||||
/// A citation attached to an assistant message.
///
/// All scalar fields are optional; `sources` defaults to an empty list.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Citation {
    #[serde(default)]
    pub start: Option<u32>,
    #[serde(default)]
    pub end: Option<u32>,
    #[serde(default)]
    pub text: Option<String>,
    /// Stored under the JSON key `type`.
    #[serde(rename = "type")]
    pub citation_type: Option<CitationType>,
    #[serde(default)]
    pub sources: Vec<Source>,
}
|
||||
|
||||
/// Source referenced by a [`Citation`], tagged by its lowercase `type` field
/// (`"document"` or `"tool"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum Source {
    /// A document source with an optional id and an optional JSON object.
    Document {
        id: Option<String>,
        document: Option<serde_json::Map<String, serde_json::Value>>,
    },
    /// A tool source with an optional id and an optional JSON object of tool output.
    Tool {
        id: Option<String>,
        tool_output: Option<serde_json::Map<String, serde_json::Value>>,
    },
}
|
||||
|
||||
/// Kind of a [`Citation`]; serialized as `"TEXT_CONTENT"` or `"PLAN"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum CitationType {
    TextContent,
    Plan,
}
|
||||
|
||||
impl TryFrom<message::Message> for Vec<Message> {
|
||||
type Error = message::MessageError;
|
||||
|
||||
fn try_from(message: message::Message) -> Result<Self, Self::Error> {
|
||||
Ok(match message {
|
||||
message::Message::User { content } => content
|
||||
.into_iter()
|
||||
.map(|content| match content {
|
||||
message::UserContent::Text(message::Text { text }) => Ok(Message::User {
|
||||
content: OneOrMany::one(UserContent::Text { text }),
|
||||
}),
|
||||
message::UserContent::ToolResult(message::ToolResult { id, content }) => {
|
||||
Ok(Message::Tool {
|
||||
tool_call_id: id,
|
||||
content: content.try_map(|content| match content {
|
||||
message::ToolResultContent::Text(text) => {
|
||||
Ok(ToolResultContent::Text { text: text.text })
|
||||
}
|
||||
_ => Err(message::MessageError::ConversionError(
|
||||
"Only text tool result content is supported by Cohere"
|
||||
.to_owned(),
|
||||
)),
|
||||
})?,
|
||||
})
|
||||
}
|
||||
_ => Err(message::MessageError::ConversionError(
|
||||
"Only text content is supported by Cohere".to_owned(),
|
||||
)),
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
message::Message::Assistant { content } => {
|
||||
let mut text_content = vec![];
|
||||
let mut tool_calls = vec![];
|
||||
content.into_iter().for_each(|content| match content {
|
||||
message::AssistantContent::Text(message::Text { text }) => {
|
||||
text_content.push(AssistantContent::Text { text });
|
||||
}
|
||||
message::AssistantContent::ToolCall(message::ToolCall {
|
||||
id,
|
||||
function: message::ToolFunction { name, arguments },
|
||||
}) => {
|
||||
tool_calls.push(ToolCall {
|
||||
id: Some(id),
|
||||
r#type: Some(ToolType::Function),
|
||||
function: Some(ToolCallFunction {
|
||||
name,
|
||||
arguments: serde_json::to_value(arguments).unwrap_or_default(),
|
||||
}),
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
vec![Message::Assistant {
|
||||
content: text_content,
|
||||
citations: vec![],
|
||||
tool_calls,
|
||||
tool_plan: None,
|
||||
}]
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Message> for message::Message {
|
||||
type Error = message::MessageError;
|
||||
|
||||
fn try_from(message: Message) -> Result<Self, Self::Error> {
|
||||
match message {
|
||||
Message::User { content } => Ok(message::Message::User {
|
||||
content: content.map(|content| match content {
|
||||
UserContent::Text { text } => {
|
||||
message::UserContent::Text(message::Text { text })
|
||||
}
|
||||
UserContent::ImageUrl { image_url } => message::UserContent::image(
|
||||
image_url.url,
|
||||
Some(message::ContentFormat::String),
|
||||
None,
|
||||
None,
|
||||
),
|
||||
}),
|
||||
}),
|
||||
Message::Assistant {
|
||||
content,
|
||||
tool_calls,
|
||||
..
|
||||
} => {
|
||||
let mut content = content
|
||||
.into_iter()
|
||||
.map(|content| match content {
|
||||
AssistantContent::Text { text } => message::AssistantContent::text(text),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
content.extend(tool_calls.into_iter().filter_map(|tool_call| {
|
||||
let ToolCallFunction { name, arguments } = tool_call.function?;
|
||||
|
||||
Some(message::AssistantContent::tool_call(
|
||||
tool_call.id.unwrap_or_else(|| name.clone()),
|
||||
name,
|
||||
arguments,
|
||||
))
|
||||
}));
|
||||
|
||||
let content = OneOrMany::many(content).map_err(|_| {
|
||||
message::MessageError::ConversionError(
|
||||
"Expected either text content or tool calls".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(message::Message::Assistant { content })
|
||||
}
|
||||
Message::Tool {
|
||||
content,
|
||||
tool_call_id,
|
||||
} => {
|
||||
let content = content.try_map(|content| {
|
||||
Ok(match content {
|
||||
ToolResultContent::Text { text } => message::ToolResultContent::text(text),
|
||||
ToolResultContent::Document { document } => {
|
||||
message::ToolResultContent::text(
|
||||
serde_json::to_string(&document.data).map_err(|e| {
|
||||
message::MessageError::ConversionError(
|
||||
format!("Failed to convert tool result document content into text: {}", e),
|
||||
)
|
||||
})?,
|
||||
)
|
||||
}
|
||||
})
|
||||
})?;
|
||||
|
||||
Ok(message::Message::User {
|
||||
content: OneOrMany::one(message::UserContent::tool_result(
|
||||
tool_call_id,
|
||||
content,
|
||||
)),
|
||||
})
|
||||
}
|
||||
Message::System { content } => Ok(message::Message::user(content)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Completion model handle: a client plus the name of the model to request.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Model name sent as the `model` field of each request.
    pub model: String,
}
|
||||
|
||||
impl CompletionModel {
|
||||
pub fn new(client: Client, model: &str) -> Self {
|
||||
Self {
|
||||
client,
|
||||
model: model.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl completion::CompletionModel for CompletionModel {
|
||||
type Response = CompletionResponse;
|
||||
|
||||
#[cfg_attr(feature = "worker", worker::send)]
|
||||
async fn completion(
|
||||
&self,
|
||||
completion_request: completion::CompletionRequest,
|
||||
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
|
||||
let prompt = completion_request.prompt_with_context();
|
||||
|
||||
let mut messages: Vec<message::Message> =
|
||||
if let Some(preamble) = completion_request.preamble {
|
||||
vec![preamble.into()]
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
messages.extend(completion_request.chat_history);
|
||||
messages.push(prompt);
|
||||
|
||||
let messages: Vec<Message> = messages
|
||||
.into_iter()
|
||||
.map(|msg| msg.try_into())
|
||||
.collect::<Result<Vec<Vec<_>>, _>>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect();
|
||||
|
||||
let request = json!({
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"documents": completion_request.documents,
|
||||
"temperature": completion_request.temperature,
|
||||
"tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
|
||||
});
|
||||
|
||||
tracing::debug!(
|
||||
"Cohere request: {}",
|
||||
serde_json::to_string_pretty(&request)?
|
||||
);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post("/v2/chat")
|
||||
.json(
|
||||
&if let Some(ref params) = completion_request.additional_params {
|
||||
json_utils::merge(request.clone(), params.clone())
|
||||
} else {
|
||||
request.clone()
|
||||
},
|
||||
)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let text_response = response.text().await?;
|
||||
tracing::debug!("Cohere response text: {}", text_response);
|
||||
|
||||
let json_response: CompletionResponse = serde_json::from_str(&text_response)?;
|
||||
let completion: completion::CompletionResponse<CompletionResponse> =
|
||||
json_response.try_into()?;
|
||||
Ok(completion)
|
||||
} else {
|
||||
Err(CompletionError::ProviderError(response.text().await?))
|
||||
}
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// A tool-call response deserializes with the expected id, finish reason, usage
    /// numbers and tool-call arguments.
    #[test]
    fn test_deserialize_completion_response() {
        let json_data = r#"
        {
            "id": "abc123",
            "message": {
                "role": "assistant",
                "tool_plan": "I will use the subtract tool to find the difference between 2 and 5.",
                "tool_calls": [
                    {
                        "id": "subtract_sm6ps6fb6y9f",
                        "type": "function",
                        "function": {
                            "name": "subtract",
                            "arguments": "{\"x\":5,\"y\":2}"
                        }
                    }
                ]
            },
            "finish_reason": "TOOL_CALL",
            "usage": {
                "billed_units": {
                    "input_tokens": 78,
                    "output_tokens": 27
                },
                "tokens": {
                    "input_tokens": 1028,
                    "output_tokens": 63
                }
            }
        }
        "#;

        let mut deserializer = serde_json::Deserializer::from_str(json_data);
        let result: Result<CompletionResponse, _> = deserialize(&mut deserializer);

        let response = result.unwrap();
        let (_, citations, tool_calls) = response.message();
        let CompletionResponse {
            id,
            finish_reason,
            usage,
            ..
        } = response;

        assert_eq!(id, "abc123");
        assert_eq!(finish_reason, FinishReason::ToolCall);

        let Usage {
            billed_units,
            tokens,
        } = usage.unwrap();
        let BilledUnits {
            input_tokens: billed_input_tokens,
            output_tokens: billed_output_tokens,
            ..
        } = billed_units.unwrap();
        let Tokens {
            input_tokens,
            output_tokens,
        } = tokens.unwrap();

        assert_eq!(billed_input_tokens.unwrap(), 78.0);
        assert_eq!(billed_output_tokens.unwrap(), 27.0);
        assert_eq!(input_tokens.unwrap(), 1028.0);
        assert_eq!(output_tokens.unwrap(), 63.0);

        assert!(citations.is_empty());
        assert_eq!(tool_calls.len(), 1);

        let ToolCallFunction { name, arguments } = tool_calls[0].function.clone().unwrap();

        assert_eq!(name, "subtract");
        assert_eq!(arguments, serde_json::json!({"x": 5, "y": 2}));
    }

    /// A generic text message converts to exactly one provider message, which converts
    /// back without error.
    #[test]
    fn test_convert_completion_message_to_message_and_back() {
        let completion_message = completion::Message::User {
            content: OneOrMany::one(completion::message::UserContent::Text(
                completion::message::Text {
                    text: "Hello, world!".to_string(),
                },
            )),
        };

        let messages: Vec<Message> = completion_message.clone().try_into().unwrap();
        assert_eq!(
            messages,
            vec![Message::User {
                content: OneOrMany::one(UserContent::Text {
                    text: "Hello, world!".to_string(),
                }),
            }]
        );

        let converted_back: Vec<completion::Message> = messages
            .into_iter()
            .map(|msg| msg.try_into().unwrap())
            .collect::<Vec<_>>();
        assert_eq!(converted_back.len(), 1);
    }

    /// A provider text message survives the round trip through the generic message type
    /// unchanged.
    #[test]
    fn test_convert_message_to_completion_message_and_back() {
        let message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello, world!".to_string(),
            }),
        };

        let completion_message: completion::Message = message.clone().try_into().unwrap();
        let converted_back: Vec<Message> = completion_message.try_into().unwrap();

        assert_eq!(converted_back, vec![message]);
    }
}
|
|
@ -0,0 +1,142 @@
|
|||
use super::{client::ApiResponse, Client};
|
||||
|
||||
use crate::embeddings::{self, EmbeddingError};
|
||||
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
|
||||
/// Successful body of an embed request.
#[derive(Deserialize)]
pub struct EmbeddingResponse {
    #[serde(default)]
    pub response_type: Option<String>,
    pub id: String,
    /// One embedding vector per entry; `embed_texts` pairs them with the input documents
    /// by position.
    pub embeddings: Vec<Vec<f64>>,
    pub texts: Vec<String>,
    /// Optional metadata; `None` when absent.
    #[serde(default)]
    pub meta: Option<Meta>,
}
|
||||
|
||||
/// Metadata section of an [`EmbeddingResponse`]: API version, billed units and warnings.
#[derive(Deserialize)]
pub struct Meta {
    pub api_version: ApiVersion,
    pub billed_units: BilledUnits,
    /// Defaults to an empty list when missing.
    #[serde(default)]
    pub warnings: Vec<String>,
}
|
||||
|
||||
/// API version information from [`Meta`]; the two flags are optional.
#[derive(Deserialize)]
pub struct ApiVersion {
    pub version: String,
    #[serde(default)]
    pub is_deprecated: Option<bool>,
    #[serde(default)]
    pub is_experimental: Option<bool>,
}
|
||||
|
||||
/// Billed-unit counters from [`Meta`]; each counter defaults to `0` when missing.
#[derive(Deserialize, Debug)]
pub struct BilledUnits {
    #[serde(default)]
    pub input_tokens: u32,
    #[serde(default)]
    pub output_tokens: u32,
    #[serde(default)]
    pub search_units: u32,
    #[serde(default)]
    pub classifications: u32,
}
|
||||
|
||||
impl std::fmt::Display for BilledUnits {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Input tokens: {}\nOutput tokens: {}\nSearch units: {}\nClassifications: {}",
|
||||
self.input_tokens, self.output_tokens, self.search_units, self.classifications
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Embedding model handle: a client, the model name, the `input_type` sent with each
/// request and the number of dimensions reported by `ndims()`.
#[derive(Clone)]
pub struct EmbeddingModel {
    client: Client,
    pub model: String,
    pub input_type: String,
    ndims: usize,
}
|
||||
|
||||
impl embeddings::EmbeddingModel for EmbeddingModel {
|
||||
const MAX_DOCUMENTS: usize = 96;
|
||||
|
||||
fn ndims(&self) -> usize {
|
||||
self.ndims
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "worker", worker::send)]
|
||||
async fn embed_texts(
|
||||
&self,
|
||||
documents: impl IntoIterator<Item = String>,
|
||||
) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
|
||||
let documents = documents.into_iter().collect::<Vec<_>>();
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post("/v1/embed")
|
||||
.json(&json!({
|
||||
"model": self.model,
|
||||
"texts": documents,
|
||||
"input_type": self.input_type,
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if response.status().is_success() {
|
||||
match response.json::<ApiResponse<EmbeddingResponse>>().await? {
|
||||
ApiResponse::Ok(response) => {
|
||||
match response.meta {
|
||||
Some(meta) => tracing::info!(target: "rig",
|
||||
"Cohere embeddings billed units: {}",
|
||||
meta.billed_units,
|
||||
),
|
||||
None => tracing::info!(target: "rig",
|
||||
"Cohere embeddings billed units: n/a",
|
||||
),
|
||||
};
|
||||
|
||||
if response.embeddings.len() != documents.len() {
|
||||
return Err(EmbeddingError::DocumentError(
|
||||
format!(
|
||||
"Expected {} embeddings, got {}",
|
||||
documents.len(),
|
||||
response.embeddings.len()
|
||||
)
|
||||
.into(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(response
|
||||
.embeddings
|
||||
.into_iter()
|
||||
.zip(documents.into_iter())
|
||||
.map(|(embedding, document)| embeddings::Embedding {
|
||||
document,
|
||||
vec: embedding,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)),
|
||||
}
|
||||
} else {
|
||||
Err(EmbeddingError::ProviderError(response.text().await?))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EmbeddingModel {
|
||||
pub fn new(client: Client, model: &str, input_type: &str, ndims: usize) -> Self {
|
||||
Self {
|
||||
client,
|
||||
model: model.to_string(),
|
||||
input_type: input_type.to_string(),
|
||||
ndims,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
//! Cohere API client and Rig integration
|
||||
//!
|
||||
//! # Example
|
||||
//! ```
|
||||
//! use rig::providers::cohere;
|
||||
//!
|
||||
//! let client = cohere::Client::new("YOUR_API_KEY");
|
||||
//!
|
||||
//! let command_r = client.completion_model(cohere::COMMAND_R);
|
||||
//! ```
|
||||
|
||||
pub mod client;
|
||||
pub mod completion;
|
||||
pub mod embeddings;
|
||||
|
||||
pub use client::Client;
|
||||
pub use client::{ApiErrorResponse, ApiResponse};
|
||||
pub use completion::CompletionModel;
|
||||
pub use embeddings::EmbeddingModel;
|
||||
|
||||
// ================================================================
|
||||
// Cohere Completion Models
|
||||
// ================================================================
|
||||
|
||||
/// `command-r-plus` completion model
pub const COMMAND_R_PLUS: &str = "command-r-plus";
/// `command-r` completion model
pub const COMMAND_R: &str = "command-r";
/// `command` completion model
pub const COMMAND: &str = "command";
/// `command-nightly` completion model
pub const COMMAND_NIGHTLY: &str = "command-nightly";
/// `command-light` completion model
pub const COMMAND_LIGHT: &str = "command-light";
/// `command-light-nightly` completion model
pub const COMMAND_LIGHT_NIGHTLY: &str = "command-light-nightly";

// ================================================================
// Cohere Embedding Models
// ================================================================

/// `embed-english-v3.0` embedding model
pub const EMBED_ENGLISH_V3: &str = "embed-english-v3.0";
/// `embed-english-light-v3.0` embedding model
pub const EMBED_ENGLISH_LIGHT_V3: &str = "embed-english-light-v3.0";
/// `embed-multilingual-v3.0` embedding model
pub const EMBED_MULTILINGUAL_V3: &str = "embed-multilingual-v3.0";
/// `embed-multilingual-light-v3.0` embedding model
pub const EMBED_MULTILINGUAL_LIGHT_V3: &str = "embed-multilingual-light-v3.0";
/// `embed-english-v2.0` embedding model
pub const EMBED_ENGLISH_V2: &str = "embed-english-v2.0";
/// `embed-english-light-v2.0` embedding model
pub const EMBED_ENGLISH_LIGHT_V2: &str = "embed-english-light-v2.0";
/// `embed-multilingual-v2.0` embedding model
pub const EMBED_MULTILINGUAL_V2: &str = "embed-multilingual-v2.0";
|
Loading…
Reference in New Issue