refactor: Split Chat and Prompt traits, rename CompletionModel::T to Response, expand docs

Christophe 2024-06-08 03:58:00 -04:00
parent 7585f0e742
commit c8a4ae39f7
No known key found for this signature in database
GPG Key ID: 1BB7A015FC6BC135
25 changed files with 612 additions and 143 deletions
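The core of the change, reconstructed from the completion-module hunks below: the old `Prompt` trait, which bundled `prompt` and `chat`, is split into a one-shot `Prompt` trait and a history-aware `Chat` trait. A sketch of the two traits as they stand after this commit:

```rust
/// One-shot interface: prompt in, response out.
pub trait Prompt: Send + Sync {
    fn prompt(
        &self,
        prompt: &str,
    ) -> impl std::future::Future<Output = Result<String, PromptError>> + Send;
}

/// Chat interface: prompt and chat history in, response out.
pub trait Chat: Send + Sync {
    fn chat(
        &self,
        prompt: &str,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<String, PromptError>> + Send;
}
```

`Agent`, `Model`, and `RagAgent` implement both traits, with `prompt` delegating to `chat` with an empty history.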

View File

@ -15,16 +15,8 @@ async fn main() -> Result<(), anyhow::Error> {
.preamble("You are a comedian here to entertain the user using humour and jokes.")
.build();
// let client = providers::cohere::Client::new(
// &env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set"),
// );
// let comedian_agent = client
// .agent("command-r")
// .preamble("You are a comedian here to entertain the user using humour and jokes.")
// .build();
// Prompt the agent and print the response
let response = comedian_agent.chat("Entertain me!", vec![]).await?;
let response = comedian_agent.prompt("Entertain me!").await?;
println!("{}", response);
Ok(())
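The example migrations in this commit are mechanical wherever the history argument was an empty `vec![]`. A before/after sketch (the multi-turn line is illustrative; `chat_history` stands in for any existing `Vec<Message>`):

```rust
// Before: one-shot call through the chat interface, padded with an empty history
let response = comedian_agent.chat("Entertain me!", vec![]).await?;

// After: the dedicated one-shot interface
let response = comedian_agent.prompt("Entertain me!").await?;

// Multi-turn conversations keep using `chat` and pass the history explicitly
let response = comedian_agent.chat("Tell me another one!", chat_history).await?;
```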

View File

@ -20,9 +20,7 @@ async fn main() -> Result<(), anyhow::Error> {
.build();
// Prompt the agent and print the response
let response = agent
.chat("What does \"glarb-glarb\" mean?", vec![])
.await?;
let response = agent.prompt("What does \"glarb-glarb\" mean?").await?;
println!("{}", response);

View File

@ -28,10 +28,10 @@ impl Tool for Adder {
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
serde_json::from_value(json!({
"name": "add",
"description": "Add x and y together",
"parameters": {
ToolDefinition {
name: "add".to_string(),
description: "Add x and y together".to_string(),
parameters: json!({
"type": "object",
"properties": {
"x": {
@ -43,9 +43,8 @@ impl Tool for Adder {
"description": "The second number to add"
}
}
}
}))
.expect("Tool Definition")
}),
}
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
@ -120,15 +119,11 @@ async fn main() -> Result<(), anyhow::Error> {
println!("Calculate 2 - 5");
println!(
"GPT-4: {}",
gpt4_calculator_agent
.chat("Calculate 2 - 5", vec![])
.await?
gpt4_calculator_agent.prompt("Calculate 2 - 5").await?
);
println!(
"Coral: {}",
coral_calculator_agent
.chat("Calculate 2 - 5", vec![])
.await?
coral_calculator_agent.prompt("Calculate 2 - 5").await?
);
Ok(())

View File

@ -22,9 +22,7 @@ async fn main() -> Result<(), anyhow::Error> {
// Prompt the model and print the response
// We use `prompt` to get a simple response from the model as a String
let response = klimadao_agent
.chat("Tell me about BCT tokens?", vec![])
.await?;
let response = klimadao_agent.prompt("Tell me about BCT tokens?").await?;
println!("\n\nCoral: {:?}", response);

View File

@ -3,7 +3,7 @@ use std::env;
use anyhow::Result;
use rig::{
agent::Agent,
completion::{Message, Prompt},
completion::{Chat, Message},
providers::{cohere, openai},
};

View File

@ -1,6 +1,4 @@
use std::env;
use rig::providers::{cohere::Client as CohereClient, openai::Client as OpenAIClient};
use rig::providers::openai;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -18,8 +16,7 @@ struct Person {
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Create OpenAI client
let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
let openai_client = OpenAIClient::new(&openai_api_key);
let openai_client = openai::Client::from_env();
// Create extractor
let data_extractor = openai_client.extractor::<Person>("gpt-4").build();
@ -30,18 +27,5 @@ async fn main() -> Result<(), anyhow::Error> {
println!("GPT-4: {}", serde_json::to_string_pretty(&person).unwrap());
// Create Cohere client
let cohere_api_key = env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
let cohere_client = CohereClient::new(&cohere_api_key);
// Create extractor
let data_extractor = cohere_client.extractor::<Person>("command-r").build();
let person = data_extractor
.extract("Hello my name is John Doe! I am a software engineer.")
.await?;
println!("Coral: {}", serde_json::to_string_pretty(&person).unwrap());
Ok(())
}

View File

@ -3,7 +3,7 @@ use std::env;
use rig::{
agent::{Agent, AgentBuilder},
cli_chatbot::cli_chatbot,
completion::{CompletionModel, Message, Prompt, PromptError},
completion::{Chat, CompletionModel, Message, PromptError},
model::{Model, ModelBuilder},
providers::openai::Client as OpenAIClient,
};
@ -34,7 +34,7 @@ impl<M: CompletionModel> EnglishTranslator<M> {
}
}
impl<M: CompletionModel> Prompt for EnglishTranslator<M> {
impl<M: CompletionModel> Chat for EnglishTranslator<M> {
async fn chat(&self, prompt: &str, chat_history: Vec<Message>) -> Result<String, PromptError> {
// Translate the prompt using the translator agent
let translated_prompt = self

View File

@ -3,7 +3,7 @@ use std::env;
use rig::{
completion::Prompt,
embeddings::EmbeddingsBuilder,
providers::openai::Client,
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
};
@ -13,7 +13,7 @@ async fn main() -> Result<(), anyhow::Error> {
let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
let openai_client = Client::new(&openai_api_key);
let embedding_model = openai_client.embedding_model("text-embedding-ada-002");
let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
// Create vector store, compute embeddings and load them in the store
let mut vector_store = InMemoryVectorStore::default();
@ -39,9 +39,7 @@ async fn main() -> Result<(), anyhow::Error> {
.build();
// Prompt the agent and print the response
let response = rag_agent
.chat("What does \"glarb-glarb\" mean?", vec![])
.await?;
let response = rag_agent.prompt("What does \"glarb-glarb\" mean?").await?;
println!("{}", response);

View File

@ -178,7 +178,7 @@ async fn main() -> Result<(), anyhow::Error> {
.build();
// Prompt the agent and print the response
let response = calculator_rag.chat("Calculate 3 - 7", vec![]).await?;
let response = calculator_rag.prompt("Calculate 3 - 7").await?;
println!("{}", response);
Ok(())

View File

@ -0,0 +1,35 @@
use rig::providers::openai;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
#[derive(Debug, Deserialize, JsonSchema, Serialize)]
/// An enum representing the sentiment of a document
enum Sentiment {
Positive,
Negative,
Neutral,
}
#[derive(Debug, Deserialize, JsonSchema, Serialize)]
struct DocumentSentiment {
/// The sentiment of the document
sentiment: Sentiment,
}
#[tokio::main]
async fn main() {
// Create OpenAI client
let openai_client = openai::Client::from_env();
// Create extractor
let data_extractor = openai_client
.extractor::<DocumentSentiment>("gpt-4")
.build();
let sentiment = data_extractor
.extract("I am happy")
.await
.expect("Failed to extract sentiment");
println!("GPT-4: {:?}", sentiment);
}

View File

@ -1,28 +1,17 @@
use std::env;
use rig::{
completion::Prompt,
providers::{cohere, openai},
};
use rig::{completion::Prompt, providers::openai};
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
async fn main() {
// Create OpenAI client and model
let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
let openai_client = openai::Client::new(&openai_api_key);
let openai_client = openai::Client::from_env();
let gpt4 = openai_client.model("gpt-4").temperature(0.0).build();
let gpt4 = openai_client.model("gpt-4").build();
// Create Cohere client and model
let cohere_api_key = env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
let cohere_client = cohere::Client::new(&cohere_api_key);
// Prompt the model and print its response
let response = gpt4
.prompt("Who are you?")
.await
.expect("Failed to prompt GPT-4");
let command_r = cohere_client.model("command-r").temperature(0.0).build();
// Prompt the models and print their response
println!("Question: Who are you?");
println!("GPT-4: {:?}", gpt4.chat("Who are you?", vec![]).await?);
println!("Coral: {:?}", command_r.chat("Who are you?", vec![]).await?);
Ok(())
println!("GPT-4: {response}");
}

View File

@ -1,20 +1,71 @@
//! This module contains the implementation of the `Agent` struct and its builder.
//!
//! The `Agent` struct represents an LLM agent, which combines an LLM model with a preamble (system prompt),
//! a set of context documents, and a set of static tools. The agent can be used to interact with the LLM model
//! by providing prompts and chat history.
//!
//! The `AgentBuilder` struct provides a builder pattern for creating instances of the `Agent` struct.
//! It allows configuring the model, preamble, context documents, static tools, temperature, and additional parameters
//! before building the agent.
//!
//! Example usage:
//!
//! ```rust
//! use rig::{completion::Prompt, providers::openai};
//!
//! let openai_client = openai::Client::from_env();
//!
//! // Configure the agent
//! let agent = openai_client.agent("gpt-4o")
//! .preamble("System prompt")
//! .context("Context document 1")
//! .context("Context document 2")
//! .tool(tool1)
//! .tool(tool2)
//! .temperature(0.8)
//! .additional_params(json!({"foo": "bar"}))
//! .build();
//!
//! // Use the agent for completions and prompts
//! let completion_req_builder = agent.completion("Prompt", chat_history).await;
//! let chat_response = agent.chat("Prompt", chat_history).await;
//! ```
//!
//! For more information on how to use the `Agent` struct and its builder, refer to the documentation of the respective structs and methods.
use std::collections::HashMap;
use futures::{stream, StreamExt};
use crate::{
completion::{
Completion, CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse,
Document, Message, ModelChoice, Prompt, PromptError,
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder,
CompletionResponse, Document, Message, ModelChoice, Prompt, PromptError,
},
tool::{Tool, ToolSet},
};
/// Struct representing an LLM agent. An agent is an LLM model
/// combined with static context (i.e.: always inserted at the top
/// of the chat history before any user prompts) and static tools.
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
/// All context documents and tools are always provided to the agent when prompted.
///
/// # Example
/// ```
/// use rig::{completion::Prompt, providers::openai};
///
/// let openai_client = openai::Client::from_env();
///
/// let comedian_agent = openai_client
/// .agent("gpt-4o")
/// .preamble("You are a comedian here to entertain the user using humour and jokes.")
/// .temperature(0.9)
/// .build();
///
/// let response = comedian_agent.prompt("Entertain me!")
/// .await
/// .expect("Failed to prompt GPT-4");
/// ```
pub struct Agent<M: CompletionModel> {
/// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
/// Completion model (e.g.: OpenAI's `gpt-3.5-turbo-1106`, Cohere's `command-r`)
model: M,
/// System prompt
preamble: String,
@ -54,7 +105,7 @@ impl<M: CompletionModel> Agent<M> {
additional_props: HashMap::new(),
})
.collect(),
tools: ToolSet::new(static_tools),
tools: ToolSet::from_tools(static_tools),
static_tools: static_tools_ids,
temperature,
additional_params,
@ -73,7 +124,7 @@ impl<M: CompletionModel> Completion<M> for Agent<M> {
if let Some(tool) = self.tools.get(toolname) {
Some(tool.definition(prompt.into()).await)
} else {
tracing::error!(target: "ai", "Agent static tool {} not found", toolname);
tracing::error!(target: "rig", "Agent static tool {} not found", toolname);
None
}
})
@ -93,6 +144,12 @@ impl<M: CompletionModel> Completion<M> for Agent<M> {
}
impl<M: CompletionModel> Prompt for Agent<M> {
async fn prompt(&self, prompt: &str) -> Result<String, PromptError> {
self.chat(prompt, vec![]).await
}
}
impl<M: CompletionModel> Chat for Agent<M> {
async fn chat(&self, prompt: &str, chat_history: Vec<Message>) -> Result<String, PromptError> {
match self.completion(prompt, chat_history).await?.send().await? {
CompletionResponse {
@ -107,6 +164,27 @@ impl<M: CompletionModel> Prompt for Agent<M> {
}
}
/// A builder for creating an agent
///
/// # Example
/// ```
/// use rig::{providers::openai, agent::AgentBuilder};
///
/// let openai_client = openai::Client::from_env();
///
/// let gpt4 = openai_client.completion_model("gpt-4");
///
/// // Configure the agent
/// let agent = AgentBuilder::new(gpt4)
/// .preamble("System prompt")
/// .context("Context document 1")
/// .context("Context document 2")
/// .tool(tool1)
/// .tool(tool2)
/// .temperature(0.8)
/// .additional_params(json!({"foo": "bar"}))
/// .build();
/// ```
pub struct AgentBuilder<M: CompletionModel> {
model: M,
preamble: Option<String>,

View File

@ -1,8 +1,8 @@
use std::io::{self, Write};
use crate::completion::{Message, Prompt, PromptError};
use crate::completion::{Chat, Message, PromptError};
pub async fn cli_chatbot(chatbot: impl Prompt) -> Result<(), PromptError> {
pub async fn cli_chatbot(chatbot: impl Chat) -> Result<(), PromptError> {
let stdin = io::stdin();
let mut stdout = io::stdout();
let mut chat_log = vec![];

View File

@ -1,3 +1,53 @@
//! This module contains the implementation of the completion functionality for the LLM (Large Language
//! Model) chat interface. It provides traits, structs, and enums for generating completion requests,
//! handling completion responses, and defining completion models.
//!
//! The main traits defined in this module are:
//! - `Prompt`: Defines a high-level LLM one-shot prompt interface for prompting and receiving responses.
//! - `Chat`: Defines a high-level LLM chat interface for prompting with chat history and receiving responses.
//! - `Completion`: Defines a low-level LLM completion interface for generating completion requests.
//! - `CompletionModel`: Defines a completion model that can be used to generate completion responses.
//!
//! The module also provides various structs and enums for representing generic completion requests,
//! responses, and errors.
//!
//! Example Usage:
//!
//! ```rust
//! use rig::providers::openai::{Client, self};
//! use rig::completion::*;
//!
//! // Initialize the OpenAI client and a completion model
//! let openai = Client::new("your-openai-api-key");
//!
//! let model = openai.completion_model(openai::GPT_4);
//!
//! // Create the completion request
//! let request = model.completion_request("Who are you?")
//! .preamble(
//! "You are Marvin, an extremely smart but depressed robot who is nonetheless helpful towards humanity.".to_string())
//! .build();
//!
//! // Send the completion request and get the completion response
//! let response = model.completion(request)
//! .await
//! .expect("Failed to get completion response");
//!
//! // Handle the completion response
//! match response.choice {
//! ModelChoice::Message(message) => {
//! // Handle the completion response as a message
//! println!("Received message: {}", message);
//! }
//! ModelChoice::ToolCall(tool_name, tool_params) => {
//! // Handle the completion response as a tool call
//! println!("Received tool call: {} {:?}", tool_name, tool_params);
//! }
//! }
//! ```
//!
//! For more information on how to use the completion functionality, refer to the documentation of
//! the individual traits, structs, and enums defined in this module.
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
@ -66,16 +116,17 @@ pub struct ToolDefinition {
// ================================================================
// Implementations
// ================================================================
/// Trait defining a high-level LLM chat interface (i.e.: prompt in, response out).
/// Trait defining a high-level LLM one-shot prompt interface (i.e.: prompt in, response out).
pub trait Prompt: Send + Sync {
fn prompt(
&self,
prompt: &str,
) -> impl std::future::Future<Output = Result<String, PromptError>> + Send {
self.chat(prompt, Vec::new())
}
) -> impl std::future::Future<Output = Result<String, PromptError>> + Send;
}
/// Send a prompt to the completion endpoint along with a chat history.
/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
pub trait Chat: Send + Sync {
/// Send a prompt to the completion endpoint along with a chat history.
/// If the response is a message, then it is returned as a string. If the response
/// is a tool call, then the tool is called and the result is returned as a string.
fn chat(
@ -91,10 +142,13 @@ pub trait Completion<M: CompletionModel> {
/// This function is meant to be called by the user to further customize the
/// request at prompt time before sending it.
///
/// IMPORTANT: The CompletionModel that implements this trait will already
/// populate fields (the exact fields depend on the model) in the builder.
/// ❗IMPORTANT: The type that implements this trait might have already
/// populated fields in the builder (the exact fields depend on the type).
/// For fields that have already been set by the model, calling the corresponding
/// method on the builder will overwrite the value set by the model.
///
/// For example, the request builder returned by `Agent::completion` will already
/// contain the `preamble` provided when creating the agent.
fn completion(
&self,
prompt: &str,
@ -102,25 +156,39 @@ pub trait Completion<M: CompletionModel> {
) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>> + Send;
}
/// General completion response struct that contains the high-level completion choice
/// and the raw response.
#[derive(Debug)]
pub struct CompletionResponse<T> {
/// The completion choice returned by the completion model provider
pub choice: ModelChoice,
/// The raw response returned by the completion model provider
pub raw_response: T,
}
/// Enum representing the high-level completion choice returned by the completion model provider.
#[derive(Debug)]
pub enum ModelChoice {
/// Represents a completion response as a message
Message(String),
/// Represents a completion response as a tool call of the form
/// `ToolCall(function_name, function_params)`.
ToolCall(String, serde_json::Value),
}
/// Trait defining a completion model that can be used to generate completion responses.
/// This trait is meant to be implemented by the user to define a custom completion model,
/// either from a third party provider (e.g.: OpenAI) or locally.
pub trait CompletionModel: Clone + Send + Sync {
type T: Send + Sync;
/// The raw response type returned by the underlying completion model
type Response: Send + Sync;
/// Generates a completion response for the given completion request.
fn completion(
&self,
request: CompletionRequest,
) -> impl std::future::Future<Output = Result<CompletionResponse<Self::T>, CompletionError>> + Send;
) -> impl std::future::Future<Output = Result<CompletionResponse<Self::Response>, CompletionError>>
+ Send;
fn completion_request(&self, prompt: &str) -> CompletionRequestBuilder<Self> {
CompletionRequestBuilder::new(self.clone(), prompt.to_string())
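With the associated type renamed from `T` to `Response`, a custom provider implementation now reads like the sketch below (`MyModel` and `MyApiResponse` are hypothetical placeholders, and the body only shows the shape of the mapping; the provider impls touched by this commit follow the same pattern):

```rust
#[derive(Clone)]
struct MyModel; // hypothetical third-party or local completion model

#[derive(Debug)]
struct MyApiResponse; // hypothetical raw provider response type

impl CompletionModel for MyModel {
    // Was `type T = ...;` before this commit
    type Response = MyApiResponse;

    async fn completion(
        &self,
        _request: CompletionRequest,
    ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
        // Dispatch the request to the provider, then map its reply
        // into the generic response type.
        Ok(CompletionResponse {
            choice: ModelChoice::Message("...".to_string()),
            raw_response: MyApiResponse,
        })
    }
}
```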
@ -130,8 +198,8 @@ pub trait CompletionModel: Clone + Send + Sync {
&self,
prompt: &str,
chat_history: Vec<Message>,
) -> impl std::future::Future<Output = Result<CompletionResponse<Self::T>, CompletionError>> + Send
{
) -> impl std::future::Future<Output = Result<CompletionResponse<Self::Response>, CompletionError>>
+ Send {
async move {
self.completion_request(prompt)
.messages(chat_history)
@ -141,16 +209,25 @@ pub trait CompletionModel: Clone + Send + Sync {
}
}
/// Struct representing a general completion request that can be sent to a completion model provider.
pub struct CompletionRequest {
pub temperature: Option<f64>,
/// The prompt to be sent to the completion model provider
pub prompt: String,
/// The preamble to be sent to the completion model provider
pub preamble: Option<String>,
/// The chat history to be sent to the completion model provider
pub chat_history: Vec<Message>,
/// The documents to be sent to the completion model provider
pub documents: Vec<Document>,
/// The tools to be sent to the completion model provider
pub tools: Vec<ToolDefinition>,
/// The temperature to be sent to the completion model provider
pub temperature: Option<f64>,
/// Additional provider-specific parameters to be sent to the completion model provider
pub additional_params: Option<serde_json::Value>,
}
/// Builder struct for constructing a completion request.
pub struct CompletionRequestBuilder<M: CompletionModel> {
model: M,
prompt: String,
@ -176,44 +253,52 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
}
}
/// Sets the preamble for the completion request.
pub fn preamble(mut self, preamble: String) -> Self {
self.preamble = Some(preamble);
self
}
/// Adds a message to the chat history for the completion request.
pub fn message(mut self, message: Message) -> Self {
self.chat_history.push(message);
self
}
/// Adds a list of messages to the chat history for the completion request.
pub fn messages(self, messages: Vec<Message>) -> Self {
messages
.into_iter()
.fold(self, |builder, msg| builder.message(msg))
}
/// Adds a document to the completion request.
pub fn document(mut self, document: Document) -> Self {
self.documents.push(document);
self
}
/// Adds a list of documents to the completion request.
pub fn documents(self, documents: Vec<Document>) -> Self {
documents
.into_iter()
.fold(self, |builder, doc| builder.document(doc))
}
/// Adds a tool to the completion request.
pub fn tool(mut self, tool: ToolDefinition) -> Self {
self.tools.push(tool);
self
}
/// Adds a list of tools to the completion request.
pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
tools
.into_iter()
.fold(self, |builder, tool| builder.tool(tool))
}
/// Adds additional parameters to the completion request.
pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
match self.additional_params {
Some(params) => {
@ -226,21 +311,25 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
self
}
/// Sets the additional parameters for the completion request.
pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
self.additional_params = additional_params;
self
}
/// Sets the temperature for the completion request.
pub fn temperature(mut self, temperature: f64) -> Self {
self.temperature = Some(temperature);
self
}
/// Sets the temperature for the completion request.
pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
self.temperature = temperature;
self
}
/// Builds the completion request.
pub fn build(self) -> CompletionRequest {
CompletionRequest {
prompt: self.prompt,
@ -253,7 +342,8 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
}
}
pub async fn send(self) -> Result<CompletionResponse<M::T>, CompletionError> {
/// Sends the completion request to the completion model provider and returns the completion response.
pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
let model = self.model.clone();
model.completion(self.build()).await
}
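Tying the builder methods together, a full request might be assembled as in the sketch below (`model` is any `CompletionModel`; the values are illustrative and `chat_history` stands in for a `Vec<Message>`):

```rust
use serde_json::json;

let request = model
    .completion_request("What does \"glarb-glarb\" mean?")
    .preamble("You are a helpful assistant.".to_string())
    .messages(chat_history)
    .temperature(0.7)
    .additional_params(json!({"foo": "bar"}))
    .build();

let response = model.completion(request).await?;
```

Equivalently, calling `send` on the builder performs the final `build` and `completion` in one step.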

View File

@ -1,3 +1,44 @@
//! This module provides functionality for working with embeddings and embedding models.
//! Embeddings are numerical representations of documents or other objects, typically used in
//! natural language processing (NLP) tasks such as text classification, information retrieval,
//! and document similarity.
//!
//! The module defines the `EmbeddingModel` trait, which represents an embedding model that can
//! generate embeddings for documents. It also provides an implementation of the `EmbeddingsBuilder`
//! struct, which allows users to build collections of document embeddings using different embedding
//! models and document sources.
//!
//! The module also defines the `Embedding` struct, which represents a single document embedding,
//! and the `DocumentEmbeddings` struct, which represents a document along with its associated
//! embeddings. These structs are used to store and manipulate collections of document embeddings.
//!
//! Finally, the module defines the `EmbeddingError` enum, which represents various errors that
//! can occur during embedding generation or processing.
//!
//! Example usage:
//!
//! ```rust
//! use rig::providers::openai::{Client, self};
//! use rig::embeddings::{EmbeddingModel, EmbeddingsBuilder};
//!
//! // Initialize the OpenAI client
//! let openai = Client::new("your-openai-api-key");
//!
//! // Create an instance of the `text-embedding-ada-002` model
//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002);
//!
//! // Create an embeddings builder and add documents
//! let embeddings = EmbeddingsBuilder::new(embedding_model)
//! .simple_document("doc1", "This is the first document.")
//! .simple_document("doc2", "This is the second document.")
//! .build()
//! .await
//! .expect("Failed to build embeddings.");
//!
//! // Use the generated embeddings
//! // ...
//! ```
use std::{cmp::max, collections::HashMap};
use futures::{stream, StreamExt, TryStreamExt};
@ -28,6 +69,7 @@ pub enum EmbeddingError {
ProviderError(String),
}
/// Trait for embedding models that can generate embeddings for documents.
pub trait EmbeddingModel: Clone + Sync + Send {
const MAX_DOCUMENTS: usize;
@ -56,9 +98,12 @@ pub trait EmbeddingModel: Clone + Sync + Send {
) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;
}
/// Struct that holds a single document and its embedding.
#[derive(Clone, Default, Deserialize, Serialize)]
pub struct Embedding {
/// The document that was embedded
pub document: String,
/// The embedding vector
pub vec: Vec<f64>,
}
@ -85,7 +130,12 @@ impl Embedding {
}
}
/// Struct to store the document and its embeddings.
/// Struct that holds a document and its embeddings.
///
/// The struct is designed to model any kind of document that can be serialized to JSON
/// (including a simple string). Moreover, it can hold multiple embeddings for the same
/// document, thus allowing a large or non-text document to be "RAGged" from various
/// smaller text documents.
#[derive(Clone, Eq, PartialEq, Serialize, Deserialize)]
pub struct DocumentEmbeddings {
#[serde(rename = "_id")]

View File

@ -22,8 +22,9 @@ pub enum ExtractionError {
PromptError(#[from] PromptError),
}
/// Extractor for structured data from text
pub struct Extractor<M: CompletionModel, T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> {
pub agent: Agent<M>,
agent: Agent<M>,
_t: PhantomData<T>,
}
@ -42,6 +43,7 @@ where
}
}
/// Builder for the Extractor
pub struct ExtractorBuilder<
T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static,
M: CompletionModel,
@ -50,7 +52,7 @@ pub struct ExtractorBuilder<
_t: PhantomData<T>,
}
impl<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, M: CompletionModel>
impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync, M: CompletionModel>
ExtractorBuilder<T, M>
{
pub fn new(model: M) -> Self {
@ -58,7 +60,7 @@ impl<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, M: CompletionModel>
agent_builder: AgentBuilder::new(model)
.preamble("\
You are an AI assistant whose purpose is to extract structured data from the provided text.\n\
You will have access to a `submit` function that defines the structure of the data to extract from the provided text.\n
You will have access to a `submit` function that defines the structure of the data to extract from the provided text.\n\
Use the `submit` function to submit the structured data.\n\
Be sure to fill out every field and ALWAYS CALL THE `submit` function, even with default values!!!
")
@ -81,6 +83,7 @@ impl<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, M: CompletionModel>
self
}
/// Build the Extractor
pub fn build(self) -> Extractor<M, T> {
Extractor {
agent: self.agent_builder.build(),
@ -96,13 +99,13 @@ struct SubmitTool<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> {
#[derive(Debug, thiserror::Error)]
#[error("SubmitError")]
pub struct SubmitError;
struct SubmitError;
impl<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> Tool for SubmitTool<T> {
impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync> Tool for SubmitTool<T> {
const NAME: &'static str = "submit";
type Error = SubmitError;
type Args = String;
type Output = String;
type Args = T;
type Output = T;
async fn definition(&self, _prompt: String) -> ToolDefinition {
ToolDefinition {
@ -113,7 +116,7 @@ impl<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> Tool for SubmitTool<
}
}
async fn call(&self, data: Self::Args) -> Result<String, Self::Error> {
async fn call(&self, data: Self::Args) -> Result<Self::Output, Self::Error> {
Ok(data)
}
}
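With `Args` and `Output` now the typed `T`, the `submit` tool deserializes straight into the target struct and passes it through instead of round-tripping a JSON string. The public API is unchanged; extraction still looks like this sketch (reusing the `Person` struct and client from the extractor example above):

```rust
// Create the extractor, as in the extractor example above
let data_extractor = openai_client.extractor::<Person>("gpt-4").build();

// `extract` yields the deserialized struct directly
let person: Person = data_extractor
    .extract("Hello my name is John Doe! I am a software engineer.")
    .await?;
```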

View File

@ -1,3 +1,27 @@
//! Rig is a Rust library for building LLM-powered applications, providing
//! high-level abstractions over completion models, agents, tools, extractors,
//! embeddings, and vector stores.
//!
//! Simple example:
//! ```
//! use rig::{completion::Prompt, providers::openai};
//!
//! #[tokio::main]
//! async fn main() {
//! // Create OpenAI client and model
//! let openai_client = openai::Client::from_env();
//!
//! let gpt4 = openai_client.model("gpt-4").build();
//!
//! // Prompt the model and print its response
//! let response = gpt4
//! .prompt("Who are you?")
//! .await
//! .expect("Failed to prompt GPT-4");
//!
//! println!("GPT-4: {response}");
//! }
//! ```
pub mod agent;
pub mod cli_chatbot;
pub mod completion;

View File

@ -1,12 +1,39 @@
//! This module contains the implementation of the `Model` struct and its builder.
//!
//! The `Model` struct is a thin wrapper around a completion model. It is the simplest
//! building block for creating an LLM-powered application and can be used to prompt
//! the underlying model directly, without a preamble, context documents, or tools.
//!
//! The `ModelBuilder` struct provides a builder pattern for creating instances of the `Model` struct.
//! It allows configuring the temperature before building the model.
//!
//! # Example
//! ```rust
//! use rig::{completion::Prompt, providers::openai};
//!
//! let openai_client = openai::Client::from_env();
//!
//! // Configure the model
//! let model = openai_client.model("gpt-4o")
//! .temperature(0.8)
//! .build();
//!
//! // Use the model for completions and prompts
//! let completion_req_builder = model.completion("Prompt", chat_history).await;
//! let chat_response = model.chat("Prompt", chat_history).await;
//! ```
//!
//! For more information on how to use the `Model` struct and its builder, refer to the documentation of the respective structs and methods.
use crate::completion::{
Completion, CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse,
Message, ModelChoice, Prompt, PromptError,
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder,
CompletionResponse, Message, ModelChoice, Prompt, PromptError,
};
/// A model that can be used to prompt completions from a completion model.
/// This is the simplest building block for creating an LLM powered application.
pub struct Model<M: CompletionModel> {
/// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
/// Completion model (e.g.: OpenAI's `gpt-3.5-turbo-1106`, Cohere's `command-r`)
model: M,
/// Temperature of the model
temperature: Option<f64>,
@ -27,6 +54,12 @@ impl<M: CompletionModel> Completion<M> for Model<M> {
}
impl<M: CompletionModel> Prompt for Model<M> {
async fn prompt(&self, prompt: &str) -> Result<String, PromptError> {
self.chat(prompt, vec![]).await
}
}
impl<M: CompletionModel> Chat for Model<M> {
async fn chat(&self, prompt: &str, chat_history: Vec<Message>) -> Result<String, PromptError> {
match self.completion(prompt, chat_history).await?.send().await? {
CompletionResponse {
@ -43,12 +76,28 @@ impl<M: CompletionModel> Prompt for Model<M> {
}
}
/// A builder for creating a model
///
/// # Example
/// ```
/// use rig::{providers::openai, model::ModelBuilder};
///
/// let openai_client = openai::Client::from_env();
///
/// let gpt4 = openai_client.completion_model("gpt-4");
///
/// // Configure the model
/// let model = ModelBuilder::new(gpt4)
/// .temperature(0.8)
/// .build();
/// ```
pub struct ModelBuilder<M: CompletionModel> {
model: M,
pub temperature: Option<f64>,
temperature: Option<f64>,
}
impl<M: CompletionModel> ModelBuilder<M> {
/// Create a new model builder
pub fn new(model: M) -> Self {
Self {
model,
@ -56,11 +105,19 @@ impl<M: CompletionModel> ModelBuilder<M> {
}
}
/// Set the temperature of the model
pub fn temperature(mut self, temperature: f64) -> Self {
self.temperature = Some(temperature);
self
}
/// Set the temperature of the model (set to None to use the default temperature of the model)
pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
self.temperature = temperature;
self
}
/// Build the model
pub fn build(self) -> Model<M> {
Model {
model: self.model,

View File

@ -1,3 +1,10 @@
//! Cohere API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::cohere;
//!
//! let client = cohere::Client::new("your-cohere-api-key");
//!
//! let command_r = client.model("command-r").build();
//! ```
use std::collections::HashMap;
use crate::{
@ -75,7 +82,7 @@ impl Client {
AgentBuilder::new(self.completion_model(model))
}
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync>(
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
&self,
model: &str,
) -> ExtractorBuilder<T, CompletionModel> {
@ -221,9 +228,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
})
.collect())
}
ApiResponse::Err(error) => {
return Err(EmbeddingError::ProviderError(error.message));
}
ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)),
}
}
}
@ -467,7 +472,7 @@ impl CompletionModel {
}
impl completion::CompletionModel for CompletionModel {
type T = CompletionResponse;
type Response = CompletionResponse;
async fn completion(
&self,

View File

@ -1,3 +1,4 @@
//! OpenAI API client and Rig integration
use crate::{
agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest},
@ -24,10 +25,12 @@ pub struct Client {
}
impl Client {
/// Create a new OpenAI client with the given API key.
pub fn new(api_key: &str) -> Self {
Self::from_url(api_key, OPENAI_API_BASE_URL)
}
/// Create a new OpenAI client with the given API key and base API URL.
pub fn from_url(api_key: &str, base_url: &str) -> Self {
Self {
base_url: base_url.to_string(),
@ -47,32 +50,116 @@ impl Client {
}
}
pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
/// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable.
/// Panics if the environment variable is not set.
pub fn from_env() -> Self {
let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
Self::new(&api_key)
}
fn post(&self, path: &str) -> reqwest::RequestBuilder {
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
self.http_client.post(url)
}
/// Create an embedding model with the given name.
///
/// # Example
/// ```
/// use rig::providers::openai::{Client, self};
///
/// // Initialize the OpenAI client
/// let openai = Client::new("your-open-ai-api-key");
///
/// let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_3_LARGE);
/// ```
pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
EmbeddingModel::new(self.clone(), model)
}
/// Create an embedding builder with the given embedding model.
///
/// # Example
/// ```
/// use rig::providers::openai::{Client, self};
///
/// // Initialize the OpenAI client
/// let openai = Client::new("your-open-ai-api-key");
///
/// let embeddings = openai.embeddings(openai::TEXT_EMBEDDING_3_LARGE)
/// .simple_document("doc0", "Hello, world!")
/// .simple_document("doc1", "Goodbye, world!")
/// .build()
/// .await
/// .expect("Failed to embed documents");
/// ```
pub fn embeddings(&self, model: &str) -> embeddings::EmbeddingsBuilder<EmbeddingModel> {
embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
}
/// Create a completion model with the given name.
///
/// # Example
/// ```
/// use rig::providers::openai::{Client, self};
///
/// // Initialize the OpenAI client
/// let openai = Client::new("your-open-ai-api-key");
///
/// let gpt4 = openai.completion_model(openai::GPT_4);
/// ```
pub fn completion_model(&self, model: &str) -> CompletionModel {
CompletionModel::new(self.clone(), model)
}
/// Create a model builder with the given completion model.
///
/// # Example
/// ```
/// use rig::providers::openai::{Client, self};
///
/// // Initialize the OpenAI client
/// let openai = Client::new("your-open-ai-api-key");
///
/// let completion_model = openai.model(openai::GPT_4)
/// .temperature(0.0)
/// .build();
/// ```
pub fn model(&self, model: &str) -> ModelBuilder<CompletionModel> {
ModelBuilder::new(self.completion_model(model))
}
/// Create an agent builder with the given completion model.
///
/// # Example
/// ```
/// use rig::providers::openai::{Client, self};
///
/// // Initialize the OpenAI client
/// let openai = Client::new("your-open-ai-api-key");
///
/// let agent = openai.agent(openai::GPT_4)
/// .preamble("You are comedian AI with a mission to make people laugh.")
/// .temperature(0.0)
/// .build();
/// ```
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
AgentBuilder::new(self.completion_model(model))
}
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync>(
/// Create an extractor builder with the given completion model.
///
/// # Example
/// ```
/// use rig::providers::openai::{Client, self};
///
/// // Initialize the OpenAI client
/// let openai = Client::new("your-open-ai-api-key");
///
/// let extractor = openai.extractor::<MyStruct>(openai::GPT_4)
/// .build();
/// ```
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
&self,
model: &str,
) -> ExtractorBuilder<T, CompletionModel> {
@ -374,7 +461,7 @@ impl CompletionModel {
}
impl completion::CompletionModel for CompletionModel {
type T = CompletionResponse;
type Response = CompletionResponse;
async fn completion(
&self,

View File

@ -4,8 +4,8 @@ use futures::{stream, StreamExt, TryStreamExt};
use crate::{
completion::{
Completion, CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse,
Document, Message, ModelChoice, Prompt, PromptError,
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder,
CompletionResponse, Document, Message, ModelChoice, Prompt, PromptError,
},
tool::{Tool, ToolSet, ToolSetError},
vector_store::{NoIndex, VectorStoreError, VectorStoreIndex},
@ -126,6 +126,12 @@ impl<M: CompletionModel, C: VectorStoreIndex, T: VectorStoreIndex> Completion<M>
}
impl<M: CompletionModel, C: VectorStoreIndex, T: VectorStoreIndex> Prompt for RagAgent<M, C, T> {
async fn prompt(&self, prompt: &str) -> Result<String, PromptError> {
self.chat(prompt, vec![]).await
}
}
impl<M: CompletionModel, C: VectorStoreIndex, T: VectorStoreIndex> Chat for RagAgent<M, C, T> {
async fn chat(&self, prompt: &str, chat_history: Vec<Message>) -> Result<String, PromptError> {
match self.completion(prompt, chat_history).await?.send().await? {
CompletionResponse {
@ -146,6 +152,28 @@ impl<M: CompletionModel, C: VectorStoreIndex, T: VectorStoreIndex> RagAgent<M, C
}
}
/// Builder for creating a RAG agent
///
/// # Example
/// ```
/// use rig::{providers::openai, rag_agent::RagAgentBuilder};
///
/// let openai_client = openai::Client::from_env();
///
/// let model = openai_client.completion_model("gpt-4");
///
/// // Configure the agent
/// let agent = RagAgentBuilder::new(model)
/// .preamble("System prompt")
/// .static_context("Context document 1")
/// .static_context("Context document 2")
/// .dynamic_context(2, vector_index)
/// .tool(tool1)
/// .tool(tool2)
/// .temperature(0.8)
/// .additional_params(json!({"foo": "bar"}))
/// .build();
/// ```
pub struct RagAgentBuilder<M: CompletionModel, C: VectorStoreIndex, T: VectorStoreIndex> {
/// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
model: M,

View File

@ -16,12 +16,66 @@ pub enum ToolError {
}
/// Trait that represents a simple LLM tool
///
/// # Example
/// ```
/// use rig::{completion::ToolDefinition, tool::Tool};
/// use serde::{Deserialize, Serialize};
/// use serde_json::json;
///
/// #[derive(Deserialize)]
/// struct AddArgs {
/// x: i32,
/// y: i32,
/// }
///
/// #[derive(Debug, thiserror::Error)]
/// #[error("Math error")]
/// struct MathError;
///
/// #[derive(Deserialize, Serialize)]
/// struct Adder;
///
/// impl Tool for Adder {
/// const NAME: &'static str = "add";
///
/// type Error = MathError;
/// type Args = AddArgs;
/// type Output = i32;
///
/// async fn definition(&self, _prompt: String) -> ToolDefinition {
/// ToolDefinition {
/// name: "add".to_string(),
/// description: "Add x and y together".to_string(),
/// parameters: json!({
/// "type": "object",
/// "properties": {
/// "x": {
/// "type": "number",
/// "description": "The first number to add"
/// },
/// "y": {
/// "type": "number",
/// "description": "The second number to add"
/// }
/// }
/// })
/// }
/// }
///
/// async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
/// let result = args.x + args.y;
/// Ok(result)
/// }
/// }
/// ```
pub trait Tool: Sized + Send + Sync {
/// The name of the tool. This name should be unique.
const NAME: &'static str;
/// The error type of the tool.
type Error: std::error::Error + Send + Sync + 'static;
/// The arguments type of the tool.
type Args: for<'a> Deserialize<'a> + Send + Sync;
/// The output type of the tool.
type Output: Serialize;
/// A method returning the name of the tool.
@ -180,7 +234,8 @@ pub struct ToolSet {
}
impl ToolSet {
pub fn new(tools: Vec<impl ToolDyn + 'static>) -> Self {
/// Create a new ToolSet from a list of tools
pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
let mut toolset = Self::default();
tools.into_iter().for_each(|tool| {
toolset.add_tool(tool);
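A quick usage sketch for the renamed constructor (reusing the `Adder` tool from the trait example above, and assuming the crate's blanket `ToolDyn` impl for `Tool` types):

```rust
// Build a set from a list of tools (was `ToolSet::new` before this commit)
let toolset = ToolSet::from_tools(vec![Adder]);
assert!(toolset.contains("add"));
```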
@ -188,19 +243,23 @@ impl ToolSet {
toolset
}
/// Create a toolset builder
pub fn builder() -> ToolSetBuilder {
ToolSetBuilder::default()
}
/// Check if the toolset contains a tool with the given name
pub fn contains(&self, toolname: &str) -> bool {
self.tools.contains_key(toolname)
}
/// Add a tool to the toolset
pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
self.tools
.insert(tool.name(), ToolType::Simple(Box::new(tool)));
}
/// Merge another toolset into this one
pub fn add_tools(&mut self, toolset: ToolSet) {
self.tools.extend(toolset.tools);
}
@ -209,9 +268,14 @@ impl ToolSet {
self.tools.get(toolname)
}
/// Call a tool with the given name and arguments
///
/// # Example
/// ```
/// let result = toolset.call("add", json!({"x": 1, "y": 2}).to_string()).await?;
/// ```
pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
if let Some(tool) = self.tools.get(toolname) {
tracing::info!(target: "ai",
tracing::info!(target: "rig",
"Calling tool {toolname} with args:\n{}",
serde_json::to_string_pretty(&args).unwrap_or_else(|_| args.clone())
);
@ -221,6 +285,7 @@ impl ToolSet {
}
}
/// Get the documents of all the tools in the toolset
pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
let mut docs = Vec::new();
for tool in self.tools.values() {

View File

@ -1,3 +1,4 @@
//! In-memory implementation of a vector store.
use std::{
cmp::Reverse,
collections::{BinaryHeap, HashMap},
@ -9,6 +10,8 @@ use serde::{Deserialize, Serialize};
use super::{VectorStore, VectorStoreError, VectorStoreIndex};
use crate::embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel, EmbeddingsBuilder};
/// InMemoryVectorStore is a simple in-memory vector store that stores embeddings
/// in-memory using a HashMap.
#[derive(Clone, Default, Deserialize, Serialize)]
pub struct InMemoryVectorStore {
/// The embeddings are stored in a HashMap with the document ID as the key.
@ -159,10 +162,6 @@ impl<M: EmbeddingModel> InMemoryVectorIndex<M> {
}
impl<M: EmbeddingModel + std::marker::Sync> VectorStoreIndex for InMemoryVectorIndex<M> {
async fn embed_document(&self, document: &str) -> Result<Embedding, VectorStoreError> {
Ok(self.model.embed_document(document).await?)
}
async fn top_n_from_query(
&self,
query: &str,
@ -208,7 +207,7 @@ impl<M: EmbeddingModel + std::marker::Sync> VectorStoreIndex for InMemoryVectorI
}
// Log selected tools with their distances
tracing::info!(target: "ai",
tracing::info!(target: "rig",
"Selected documents: {}",
docs.iter()
.map(|Reverse(RankingItem(distance, id, _, _))| format!("{} ({})", id, distance))

View File

@ -17,36 +17,38 @@ pub enum VectorStoreError {
DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync>),
}
/// Trait for vector stores
pub trait VectorStore: Send + Sync {
/// Query type for the vector store
type Q;
/// Add a list of documents to the vector store
fn add_documents(
&mut self,
documents: Vec<DocumentEmbeddings>,
) -> impl std::future::Future<Output = Result<(), VectorStoreError>> + Send;
/// Get the embeddings of a document by its id
fn get_document_embeddings(
&self,
id: &str,
) -> impl std::future::Future<Output = Result<Option<DocumentEmbeddings>, VectorStoreError>> + Send;
/// Get the document by its id and deserialize it into the given type
fn get_document<T: for<'a> Deserialize<'a>>(
&self,
id: &str,
) -> impl std::future::Future<Output = Result<Option<T>, VectorStoreError>> + Send;
/// Get the document by a query and deserialize it into the given type
fn get_document_by_query(
&self,
query: Self::Q,
) -> impl std::future::Future<Output = Result<Option<DocumentEmbeddings>, VectorStoreError>> + Send;
}
/// Trait for vector store indexes
pub trait VectorStoreIndex: Send + Sync {
fn embed_document(
&self,
document: &str,
) -> impl std::future::Future<Output = Result<Embedding, VectorStoreError>> + Send;
/// Get the top n documents based on the distance to the given embedding.
/// The distance is calculated as the cosine distance between the prompt and
/// the document embedding.
@ -135,10 +137,6 @@ pub trait VectorStoreIndex: Send + Sync {
pub struct NoIndex;
impl VectorStoreIndex for NoIndex {
async fn embed_document(&self, _document: &str) -> Result<Embedding, VectorStoreError> {
Ok(Embedding::default())
}
async fn top_n_from_query(
&self,
_query: &str,

View File

@ -114,10 +114,6 @@ impl<M: EmbeddingModel> MongoDbVectorIndex<M> {
}
impl<M: EmbeddingModel + std::marker::Sync + Send> VectorStoreIndex for MongoDbVectorIndex<M> {
async fn embed_document(&self, document: &str) -> Result<Embedding, VectorStoreError> {
Ok(self.model.embed_document(document).await?)
}
async fn top_n_from_query(
&self,
query: &str,
@ -166,7 +162,7 @@ impl<M: EmbeddingModel + std::marker::Sync + Send> VectorStoreIndex for MongoDbV
results.push((score, document));
}
tracing::info!(target: "ai",
tracing::info!(target: "rig",
"Selected documents: {}",
results.iter()
.map(|(distance, doc)| format!("{} ({})", doc.id, distance))