mirror of https://github.com/0xplaygrounds/rig

refactor: Split Chat and Prompt traits + other stuff

This commit is contained in:
parent 7585f0e742
commit c8a4ae39f7
@@ -15,16 +15,8 @@ async fn main() -> Result<(), anyhow::Error> {
         .preamble("You are a comedian here to entertain the user using humour and jokes.")
         .build();
 
-    // let client = providers::cohere::Client::new(
-    //     &env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set"),
-    // );
-    // let comedian_agent = client
-    //     .agent("command-r")
-    //     .preamble("You are a comedian here to entertain the user using humour and jokes.")
-    //     .build();
-
     // Prompt the agent and print the response
-    let response = comedian_agent.chat("Entertain me!", vec![]).await?;
+    let response = comedian_agent.prompt("Entertain me!").await?;
     println!("{}", response);
 
     Ok(())
@@ -20,9 +20,7 @@ async fn main() -> Result<(), anyhow::Error> {
         .build();
 
     // Prompt the agent and print the response
-    let response = agent
-        .chat("What does \"glarb-glarb\" mean?", vec![])
-        .await?;
+    let response = agent.prompt("What does \"glarb-glarb\" mean?").await?;
 
     println!("{}", response);
 
@@ -28,10 +28,10 @@ impl Tool for Adder {
     type Output = i32;
 
     async fn definition(&self, _prompt: String) -> ToolDefinition {
-        serde_json::from_value(json!({
-            "name": "add",
-            "description": "Add x and y together",
-            "parameters": {
+        ToolDefinition {
+            name: "add".to_string(),
+            description: "Add x and y together".to_string(),
+            parameters: json!({
                 "type": "object",
                 "properties": {
                     "x": {
@@ -43,9 +43,8 @@ impl Tool for Adder {
                     "description": "The second number to add"
                 }
             }
-        }
-        }))
-        .expect("Tool Definition")
+            }),
+        }
     }
 
     async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
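Editorial note on the `ToolDefinition` change above: constructing the struct directly removes the `serde_json::from_value(...).expect("Tool Definition")` round-trip, so a malformed JSON shape can no longer panic at runtime. A minimal sketch of the whole method after the change (the standalone function wrapper is illustrative, the field layout is the one shown in this hunk):

```rust
use rig::completion::ToolDefinition;
use serde_json::json;

// Sketch: the full `definition` body after this change. Building the struct
// directly cannot fail, unlike deserializing a `json!` blob back into
// `ToolDefinition` and unwrapping it with `expect`.
async fn adder_definition() -> ToolDefinition {
    ToolDefinition {
        name: "add".to_string(),
        description: "Add x and y together".to_string(),
        // Only `parameters` stays as raw JSON: it is an arbitrary JSON Schema.
        parameters: json!({
            "type": "object",
            "properties": {
                "x": { "type": "number", "description": "The first number to add" },
                "y": { "type": "number", "description": "The second number to add" }
            }
        }),
    }
}
```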
@@ -120,15 +119,11 @@ async fn main() -> Result<(), anyhow::Error> {
     println!("Calculate 2 - 5");
     println!(
         "GPT-4: {}",
-        gpt4_calculator_agent
-            .chat("Calculate 2 - 5", vec![])
-            .await?
+        gpt4_calculator_agent.prompt("Calculate 2 - 5").await?
     );
     println!(
         "Coral: {}",
-        coral_calculator_agent
-            .chat("Calculate 2 - 5", vec![])
-            .await?
+        coral_calculator_agent.prompt("Calculate 2 - 5").await?
     );
 
     Ok(())
@@ -22,9 +22,7 @@ async fn main() -> Result<(), anyhow::Error> {
 
     // Prompt the model and print the response
     // We use `prompt` to get a simple response from the model as a String
-    let response = klimadao_agent
-        .chat("Tell me about BCT tokens?", vec![])
-        .await?;
+    let response = klimadao_agent.prompt("Tell me about BCT tokens?").await?;
 
     println!("\n\nCoral: {:?}", response);
 
@@ -3,7 +3,7 @@ use std::env;
 use anyhow::Result;
 use rig::{
     agent::Agent,
-    completion::{Message, Prompt},
+    completion::{Chat, Message},
     providers::{cohere, openai},
 };
 
@@ -1,6 +1,4 @@
-use std::env;
-
-use rig::providers::{cohere::Client as CohereClient, openai::Client as OpenAIClient};
+use rig::providers::openai;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 
@@ -18,8 +16,7 @@ struct Person {
 #[tokio::main]
 async fn main() -> Result<(), anyhow::Error> {
     // Create OpenAI client
-    let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
-    let openai_client = OpenAIClient::new(&openai_api_key);
+    let openai_client = openai::Client::from_env();
 
     // Create extractor
     let data_extractor = openai_client.extractor::<Person>("gpt-4").build();
@@ -30,18 +27,5 @@ async fn main() -> Result<(), anyhow::Error> {
 
     println!("GPT-4: {}", serde_json::to_string_pretty(&person).unwrap());
 
-    // Create Cohere client
-    let cohere_api_key = env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
-    let cohere_client = CohereClient::new(&cohere_api_key);
-
-    // Create extractor
-    let data_extractor = cohere_client.extractor::<Person>("command-r").build();
-
-    let person = data_extractor
-        .extract("Hello my name is John Doe! I am a software engineer.")
-        .await?;
-
-    println!("Coral: {}", serde_json::to_string_pretty(&person).unwrap());
-
     Ok(())
 }
@@ -3,7 +3,7 @@ use std::env;
 use rig::{
     agent::{Agent, AgentBuilder},
     cli_chatbot::cli_chatbot,
-    completion::{CompletionModel, Message, Prompt, PromptError},
+    completion::{Chat, CompletionModel, Message, PromptError},
     model::{Model, ModelBuilder},
     providers::openai::Client as OpenAIClient,
 };
@@ -34,7 +34,7 @@ impl<M: CompletionModel> EnglishTranslator<M> {
     }
 }
 
-impl<M: CompletionModel> Prompt for EnglishTranslator<M> {
+impl<M: CompletionModel> Chat for EnglishTranslator<M> {
     async fn chat(&self, prompt: &str, chat_history: Vec<Message>) -> Result<String, PromptError> {
         // Translate the prompt using the translator agent
         let translated_prompt = self
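The `EnglishTranslator` change illustrates the general pattern: composite agents that need the chat history now implement `Chat` rather than `Prompt`. A sketch of the same pattern on a hypothetical wrapper type (`LoggingChat` is illustrative and not part of this diff; `tracing` is already a dependency of the crate):

```rust
use rig::completion::{Chat, Message, PromptError};

// Hypothetical wrapper used only for illustration: it forwards to an inner
// `Chat` implementor and logs each prompt on the way through.
struct LoggingChat<C: Chat>(C);

impl<C: Chat> Chat for LoggingChat<C> {
    async fn chat(&self, prompt: &str, chat_history: Vec<Message>) -> Result<String, PromptError> {
        // Log the prompt, then delegate to the wrapped implementor.
        tracing::info!(target: "rig", "prompt: {prompt}");
        self.0.chat(prompt, chat_history).await
    }
}
```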
@@ -3,7 +3,7 @@ use std::env;
 use rig::{
     completion::Prompt,
     embeddings::EmbeddingsBuilder,
-    providers::openai::Client,
+    providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
     vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
 };
 
@@ -13,7 +13,7 @@ async fn main() -> Result<(), anyhow::Error> {
     let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
     let openai_client = Client::new(&openai_api_key);
 
-    let embedding_model = openai_client.embedding_model("text-embedding-ada-002");
+    let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
 
     // Create vector store, compute embeddings and load them in the store
     let mut vector_store = InMemoryVectorStore::default();
@@ -39,9 +39,7 @@ async fn main() -> Result<(), anyhow::Error> {
         .build();
 
     // Prompt the agent and print the response
-    let response = rag_agent
-        .chat("What does \"glarb-glarb\" mean?", vec![])
-        .await?;
+    let response = rag_agent.prompt("What does \"glarb-glarb\" mean?").await?;
 
     println!("{}", response);
 
@@ -178,7 +178,7 @@ async fn main() -> Result<(), anyhow::Error> {
         .build();
 
     // Prompt the agent and print the response
-    let response = calculator_rag.chat("Calculate 3 - 7", vec![]).await?;
+    let response = calculator_rag.prompt("Calculate 3 - 7").await?;
     println!("{}", response);
 
     Ok(())
@@ -0,0 +1,35 @@
+use rig::providers::openai;
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Deserialize, JsonSchema, Serialize)]
+/// An enum representing the sentiment of a document
+enum Sentiment {
+    Positive,
+    Negative,
+    Neutral,
+}
+
+#[derive(Debug, Deserialize, JsonSchema, Serialize)]
+struct DocumentSentiment {
+    /// The sentiment of the document
+    sentiment: Sentiment,
+}
+
+#[tokio::main]
+async fn main() {
+    // Create OpenAI client
+    let openai_client = openai::Client::from_env();
+
+    // Create extractor
+    let data_extractor = openai_client
+        .extractor::<DocumentSentiment>("gpt-4")
+        .build();
+
+    let sentiment = data_extractor
+        .extract("I am happy")
+        .await
+        .expect("Failed to extract sentiment");
+
+    println!("GPT-4: {:?}", sentiment);
+}
@@ -1,28 +1,17 @@
-use std::env;
-
-use rig::{
-    completion::Prompt,
-    providers::{cohere, openai},
-};
+use rig::{completion::Prompt, providers::openai};
 
 #[tokio::main]
-async fn main() -> Result<(), anyhow::Error> {
+async fn main() {
     // Create OpenAI client and model
-    let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
-    let openai_client = openai::Client::new(&openai_api_key);
+    let openai_client = openai::Client::from_env();
 
-    let gpt4 = openai_client.model("gpt-4").temperature(0.0).build();
+    let gpt4 = openai_client.model("gpt-4").build();
 
-    // Create Cohere client and model
-    let cohere_api_key = env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
-    let cohere_client = cohere::Client::new(&cohere_api_key);
+    // Prompt the model and print its response
+    let response = gpt4
+        .prompt("Who are you?")
+        .await
+        .expect("Failed to prompt GPT-4");
 
-    let command_r = cohere_client.model("command-r").temperature(0.0).build();
-
-    // Prompt the models and print their response
-    println!("Question: Who are you?");
-    println!("GPT-4: {:?}", gpt4.chat("Who are you?", vec![]).await?);
-    println!("Coral: {:?}", command_r.chat("Who are you?", vec![]).await?);
-
-    Ok(())
+    println!("GPT-4: {response}");
 }
@@ -1,20 +1,71 @@
+//! This module contains the implementation of the `Agent` struct and its builder.
+//!
+//! The `Agent` struct represents an LLM agent, which combines an LLM model with a preamble (system prompt),
+//! a set of context documents, and a set of static tools. The agent can be used to interact with the LLM model
+//! by providing prompts and chat history.
+//!
+//! The `AgentBuilder` struct provides a builder pattern for creating instances of the `Agent` struct.
+//! It allows configuring the model, preamble, context documents, static tools, temperature, and additional parameters
+//! before building the agent.
+//!
+//! Example usage:
+//!
+//! ```rust
+//! use rig::{completion::Prompt, providers::openai};
+//!
+//! let openai_client = openai::Client::from_env();
+//!
+//! // Configure the agent
+//! let agent = openai_client.agent("gpt-4o")
+//!     .preamble("System prompt")
+//!     .context("Context document 1")
+//!     .context("Context document 2")
+//!     .tool(tool1)
+//!     .tool(tool2)
+//!     .temperature(0.8)
+//!     .additional_params(json!({"foo": "bar"}))
+//!     .build();
+//!
+//! // Use the agent for completions and prompts
+//! let completion_req_builder = agent.completion("Prompt", chat_history).await;
+//! let chat_response = agent.chat("Prompt", chat_history).await;
+//! ```
+//!
+//! For more information on how to use the `Agent` struct and its builder, refer to the documentation of the respective structs and methods.
 use std::collections::HashMap;
 
 use futures::{stream, StreamExt};
 
 use crate::{
     completion::{
-        Completion, CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse,
-        Document, Message, ModelChoice, Prompt, PromptError,
+        Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder,
+        CompletionResponse, Document, Message, ModelChoice, Prompt, PromptError,
     },
     tool::{Tool, ToolSet},
 };
 
-/// Struct representing an LLM agent. An agent is an LLM model
-/// combined with static context (i.e.: always inserted at the top
-/// of the chat history before any user prompts) and static tools.
+/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
+/// (i.e.: system prompt) and a static set of context documents and tools.
+/// All context documents and tools are always provided to the agent when prompted.
+///
+/// # Example
+/// ```
+/// use rig::{completion::Prompt, providers::openai};
+///
+/// let openai_client = openai::Client::from_env();
+///
+/// let comedian_agent = openai_client
+///     .agent("gpt-4o")
+///     .preamble("You are a comedian here to entertain the user using humour and jokes.")
+///     .temperature(0.9)
+///     .build();
+///
+/// let response = comedian_agent.prompt("Entertain me!")
+///     .await
+///     .expect("Failed to prompt GPT-4");
+/// ```
 pub struct Agent<M: CompletionModel> {
-    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
+    /// Completion model (e.g.: OpenAI's `gpt-3.5-turbo-1106`, Cohere's `command-r`)
     model: M,
     /// System prompt
     preamble: String,
@@ -54,7 +105,7 @@ impl<M: CompletionModel> Agent<M> {
                 additional_props: HashMap::new(),
             })
             .collect(),
-            tools: ToolSet::new(static_tools),
+            tools: ToolSet::from_tools(static_tools),
             static_tools: static_tools_ids,
             temperature,
             additional_params,
@@ -73,7 +124,7 @@ impl<M: CompletionModel> Completion<M> for Agent<M> {
             if let Some(tool) = self.tools.get(toolname) {
                 Some(tool.definition(prompt.into()).await)
             } else {
-                tracing::error!(target: "ai", "Agent static tool {} not found", toolname);
+                tracing::error!(target: "rig", "Agent static tool {} not found", toolname);
                 None
             }
         })
@@ -93,6 +144,12 @@ impl<M: CompletionModel> Completion<M> for Agent<M> {
 }
 
-impl<M: CompletionModel> Prompt for Agent<M> {
+impl<M: CompletionModel> Prompt for Agent<M> {
+    async fn prompt(&self, prompt: &str) -> Result<String, PromptError> {
+        self.chat(prompt, vec![]).await
+    }
+}
+
+impl<M: CompletionModel> Chat for Agent<M> {
     async fn chat(&self, prompt: &str, chat_history: Vec<Message>) -> Result<String, PromptError> {
         match self.completion(prompt, chat_history).await?.send().await? {
             CompletionResponse {
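Each implementor now spells out the `prompt` → `chat` delegation. A blanket impl inside the crate could express it once — a sketch of that alternative design (this is not what the commit does; the commit keeps `Prompt` implementable on its own, e.g. for types with a cheaper one-shot path):

```rust
use crate::completion::{Chat, Prompt, PromptError};

// Sketch of the alternative: one crate-internal blanket impl instead of
// per-type delegation. Every `Chat` implementor would get `Prompt` for
// free, at the cost of forbidding `Prompt`-only types.
impl<T: Chat> Prompt for T {
    async fn prompt(&self, prompt: &str) -> Result<String, PromptError> {
        // One-shot prompting is chatting with an empty history.
        self.chat(prompt, vec![]).await
    }
}
```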
@@ -107,6 +164,27 @@ impl<M: CompletionModel> Prompt for Agent<M> {
     }
 }
 
+/// A builder for creating an agent
+///
+/// # Example
+/// ```
+/// use rig::{providers::openai, agent::AgentBuilder};
+///
+/// let openai_client = openai::Client::from_env();
+///
+/// let gpt4 = openai_client.completion_model("gpt-4");
+///
+/// // Configure the agent
+/// let agent = AgentBuilder::new(gpt4)
+///     .preamble("System prompt")
+///     .context("Context document 1")
+///     .context("Context document 2")
+///     .tool(tool1)
+///     .tool(tool2)
+///     .temperature(0.8)
+///     .additional_params(json!({"foo": "bar"}))
+///     .build();
+/// ```
 pub struct AgentBuilder<M: CompletionModel> {
     model: M,
     preamble: Option<String>,
@@ -1,8 +1,8 @@
 use std::io::{self, Write};
 
-use crate::completion::{Message, Prompt, PromptError};
+use crate::completion::{Chat, Message, PromptError};
 
-pub async fn cli_chatbot(chatbot: impl Prompt) -> Result<(), PromptError> {
+pub async fn cli_chatbot(chatbot: impl Chat) -> Result<(), PromptError> {
     let stdin = io::stdin();
     let mut stdout = io::stdout();
     let mut chat_log = vec![];
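Since `cli_chatbot` now accepts any `Chat` implementor, wiring it to an agent stays a one-liner. A minimal sketch (the model name and preamble are illustrative):

```rust
use rig::{cli_chatbot::cli_chatbot, completion::PromptError, providers::openai};

#[tokio::main]
async fn main() -> Result<(), PromptError> {
    // Any `Chat` implementor works here; `Agent` is one (see agent.rs above).
    let agent = openai::Client::from_env()
        .agent("gpt-4")
        .preamble("You are a helpful assistant.")
        .build();

    // Runs the stdin/stdout REPL, threading the chat history between turns.
    cli_chatbot(agent).await
}
```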
@@ -1,3 +1,53 @@
+//! This module contains the implementation of the completion functionality for the LLM (Large Language
+//! Model) chat interface. It provides traits, structs, and enums for generating completion requests,
+//! handling completion responses, and defining completion models.
+//!
+//! The main traits defined in this module are:
+//! - `Prompt`: Defines a high-level LLM chat interface for prompting and receiving responses.
+//! - `Completion`: Defines a low-level LLM completion interface for generating completion requests.
+//! - `CompletionModel`: Defines a completion model that can be used to generate completion responses.
+//!
+//! The module also provides various structs and enums for representing generic completion requests,
+//! responses, and errors.
+//!
+//! Example usage:
+//!
+//! ```rust
+//! use rig::providers::openai::{Client, self};
+//! use rig::completion::*;
+//!
+//! // Initialize the OpenAI client and a completion model
+//! let openai = Client::new("your-openai-api-key");
+//!
+//! let model = openai.model(openai::GPT_4).build();
+//!
+//! // Create the completion request
+//! let request = model.completion_request("Who are you?")
+//!     .preamble(
+//!         "You are Marvin, an extremely smart but depressed robot who is nonetheless helpful towards humanity.".to_string())
+//!     .build();
+//!
+//! // Send the completion request and get the completion response
+//! let response = model.completion(request)
+//!     .await
+//!     .expect("Failed to get completion response");
+//!
+//! // Handle the completion response
+//! match response.choice {
+//!     ModelChoice::Message(message) => {
+//!         // Handle the completion response as a message
+//!         println!("Received message: {}", message);
+//!     }
+//!     ModelChoice::ToolCall(tool_name, tool_params) => {
+//!         // Handle the completion response as a tool call
+//!         println!("Received tool call: {} {:?}", tool_name, tool_params);
+//!     }
+//! }
+//! ```
+//!
+//! For more information on how to use the completion functionality, refer to the documentation of
+//! the individual traits, structs, and enums defined in this module.
 use std::collections::HashMap;
 
 use serde::{Deserialize, Serialize};
@@ -66,16 +116,17 @@ pub struct ToolDefinition {
 // ================================================================
 // Implementations
 // ================================================================
-/// Trait defining a high-level LLM chat interface (i.e.: prompt in, response out).
+/// Trait defining a high-level LLM one-shot prompt interface (i.e.: prompt in, response out).
 pub trait Prompt: Send + Sync {
     fn prompt(
         &self,
         prompt: &str,
-    ) -> impl std::future::Future<Output = Result<String, PromptError>> + Send {
-        self.chat(prompt, Vec::new())
-    }
+    ) -> impl std::future::Future<Output = Result<String, PromptError>> + Send;
 }
 
-    /// Send a prompt to the completion endpoint along with a chat history.
+/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
+pub trait Chat: Send + Sync {
+    /// Send a one-shot prompt to the completion endpoint.
+    /// If the response is a message, then it is returned as a string. If the response
+    /// is a tool call, then the tool is called and the result is returned as a string.
     fn chat(
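Caller-side view of the split: `Prompt` is now a pure one-shot interface with no default body, while `Chat` owns the history-threading variant. A sketch (the `demo` function and its bound are illustrative):

```rust
use rig::completion::{Chat, Prompt, PromptError};

// Sketch: a caller using both halves of the split interface.
async fn demo(agent: &(impl Prompt + Chat)) -> Result<(), PromptError> {
    // One-shot: no history to thread through.
    let first = agent.prompt("Entertain me!").await?;
    // Multi-turn: the caller owns and supplies the history (empty here).
    let second = agent.chat("Another one!", vec![]).await?;
    println!("{first}\n{second}");
    Ok(())
}
```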
@@ -91,10 +142,13 @@ pub trait Completion<M: CompletionModel> {
     /// This function is meant to be called by the user to further customize the
     /// request at prompt time before sending it.
     ///
-    /// IMPORTANT: The CompletionModel that implements this trait will already
-    /// populate fields (the exact fields depend on the model) in the builder.
+    /// ❗IMPORTANT: The type that implements this trait might have already
+    /// populated fields in the builder (the exact fields depend on the type).
     /// For fields that have already been set by the model, calling the corresponding
     /// method on the builder will overwrite the value set by the model.
+    ///
+    /// For example, the request builder returned by `Agent::completion` will already
+    /// contain the `preamble` provided when creating the agent.
     fn completion(
         &self,
         prompt: &str,
@@ -102,25 +156,39 @@
     ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>> + Send;
 }
 
+/// General completion response struct that contains the high-level completion choice
+/// and the raw response.
 #[derive(Debug)]
 pub struct CompletionResponse<T> {
+    /// The completion choice returned by the completion model provider
     pub choice: ModelChoice,
+    /// The raw response returned by the completion model provider
     pub raw_response: T,
 }
 
+/// Enum representing the high-level completion choice returned by the completion model provider.
 #[derive(Debug)]
 pub enum ModelChoice {
+    /// Represents a completion response as a message
     Message(String),
+    /// Represents a completion response as a tool call of the form
+    /// `ToolCall(function_name, function_params)`.
     ToolCall(String, serde_json::Value),
 }
 
+/// Trait defining a completion model that can be used to generate completion responses.
+/// This trait is meant to be implemented by the user to define a custom completion model,
+/// either from a third party provider (e.g.: OpenAI) or locally.
 pub trait CompletionModel: Clone + Send + Sync {
-    type T: Send + Sync;
+    /// The raw response type returned by the underlying completion model
+    type Response: Send + Sync;
 
+    /// Generates a completion response for the given completion request.
     fn completion(
         &self,
         request: CompletionRequest,
-    ) -> impl std::future::Future<Output = Result<CompletionResponse<Self::T>, CompletionError>> + Send;
+    ) -> impl std::future::Future<Output = Result<CompletionResponse<Self::Response>, CompletionError>>
+           + Send;
 
     fn completion_request(&self, prompt: &str) -> CompletionRequestBuilder<Self> {
         CompletionRequestBuilder::new(self.clone(), prompt.to_string())
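What the `type T` → `type Response` rename looks like from a provider implementation's side — a sketch with a stub model (`EchoModel` and its canned reply are illustrative, not part of this diff):

```rust
use rig::completion::{
    CompletionError, CompletionModel, CompletionRequest, CompletionResponse, ModelChoice,
};

// Sketch: a trivial provider that echoes the prompt back.
#[derive(Clone)]
struct EchoModel;

impl CompletionModel for EchoModel {
    // The provider's raw response type; `()` because there is no real API.
    type Response = ();

    async fn completion(
        &self,
        request: CompletionRequest,
    ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
        Ok(CompletionResponse {
            choice: ModelChoice::Message(format!("echo: {}", request.prompt)),
            raw_response: (),
        })
    }
}
```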
@@ -130,8 +198,8 @@ pub trait CompletionModel: Clone + Send + Sync {
         &self,
         prompt: &str,
         chat_history: Vec<Message>,
-    ) -> impl std::future::Future<Output = Result<CompletionResponse<Self::T>, CompletionError>> + Send
-    {
+    ) -> impl std::future::Future<Output = Result<CompletionResponse<Self::Response>, CompletionError>>
+           + Send {
         async move {
             self.completion_request(prompt)
                 .messages(chat_history)
@@ -141,16 +209,25 @@
     }
 }
 
+/// Struct representing a general completion request that can be sent to a completion model provider.
 pub struct CompletionRequest {
-    pub temperature: Option<f64>,
+    /// The prompt to be sent to the completion model provider
     pub prompt: String,
+    /// The preamble to be sent to the completion model provider
     pub preamble: Option<String>,
+    /// The chat history to be sent to the completion model provider
     pub chat_history: Vec<Message>,
+    /// The documents to be sent to the completion model provider
     pub documents: Vec<Document>,
+    /// The tools to be sent to the completion model provider
     pub tools: Vec<ToolDefinition>,
+    /// The temperature to be sent to the completion model provider
+    pub temperature: Option<f64>,
+    /// Additional provider-specific parameters to be sent to the completion model provider
     pub additional_params: Option<serde_json::Value>,
 }
 
+/// Builder struct for constructing a completion request.
 pub struct CompletionRequestBuilder<M: CompletionModel> {
     model: M,
     prompt: String,
@@ -176,44 +253,52 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
         }
     }
 
+    /// Sets the preamble for the completion request.
     pub fn preamble(mut self, preamble: String) -> Self {
         self.preamble = Some(preamble);
         self
     }
 
+    /// Adds a message to the chat history for the completion request.
     pub fn message(mut self, message: Message) -> Self {
         self.chat_history.push(message);
         self
     }
 
+    /// Adds a list of messages to the chat history for the completion request.
     pub fn messages(self, messages: Vec<Message>) -> Self {
         messages
             .into_iter()
             .fold(self, |builder, msg| builder.message(msg))
     }
 
+    /// Adds a document to the completion request.
     pub fn document(mut self, document: Document) -> Self {
         self.documents.push(document);
         self
     }
 
+    /// Adds a list of documents to the completion request.
     pub fn documents(self, documents: Vec<Document>) -> Self {
         documents
             .into_iter()
             .fold(self, |builder, doc| builder.document(doc))
     }
 
+    /// Adds a tool to the completion request.
     pub fn tool(mut self, tool: ToolDefinition) -> Self {
         self.tools.push(tool);
         self
     }
 
+    /// Adds a list of tools to the completion request.
     pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
         tools
             .into_iter()
             .fold(self, |builder, tool| builder.tool(tool))
     }
 
+    /// Adds additional parameters to the completion request.
     pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
         match self.additional_params {
             Some(params) => {
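The newly documented builder methods compose as one chain; note that the bulk setters (`messages`, `documents`, `tools`) are just folds over their singular counterparts. A usage sketch (prompt, preamble, and parameter values are illustrative):

```rust
use rig::completion::{CompletionError, CompletionModel, CompletionResponse, Message};

// Sketch: one fully-chained request; `model` is any `CompletionModel`.
async fn ask<M: CompletionModel>(
    model: &M,
    history: Vec<Message>,
) -> Result<CompletionResponse<M::Response>, CompletionError> {
    model
        .completion_request("What does \"glarb-glarb\" mean?")
        .preamble("You are a helpful assistant.".to_string())
        .messages(history) // bulk setter: folds each message into `message`
        .temperature(0.7)
        .additional_params(serde_json::json!({"answer_style": "terse"}))
        .send() // builds the request and calls `model.completion`
        .await
}
```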
@@ -226,21 +311,25 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
         self
     }
 
+    /// Sets the additional parameters for the completion request.
     pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
         self.additional_params = additional_params;
         self
     }
 
+    /// Sets the temperature for the completion request.
     pub fn temperature(mut self, temperature: f64) -> Self {
         self.temperature = Some(temperature);
         self
     }
 
+    /// Sets the temperature for the completion request.
     pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
         self.temperature = temperature;
         self
     }
 
+    /// Builds the completion request.
     pub fn build(self) -> CompletionRequest {
         CompletionRequest {
             prompt: self.prompt,
@@ -253,7 +342,8 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
         }
     }
 
-    pub async fn send(self) -> Result<CompletionResponse<M::T>, CompletionError> {
+    /// Sends the completion request to the completion model provider and returns the completion response.
+    pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
         let model = self.model.clone();
         model.completion(self.build()).await
     }
@@ -1,3 +1,44 @@
+//! This module provides functionality for working with embeddings and embedding models.
+//! Embeddings are numerical representations of documents or other objects, typically used in
+//! natural language processing (NLP) tasks such as text classification, information retrieval,
+//! and document similarity.
+//!
+//! The module defines the `EmbeddingModel` trait, which represents an embedding model that can
+//! generate embeddings for documents. It also provides an implementation of the `EmbeddingsBuilder`
+//! struct, which allows users to build collections of document embeddings using different embedding
+//! models and document sources.
+//!
+//! The module also defines the `Embedding` struct, which represents a single document embedding,
+//! and the `DocumentEmbeddings` struct, which represents a document along with its associated
+//! embeddings. These structs are used to store and manipulate collections of document embeddings.
+//!
+//! Finally, the module defines the `EmbeddingError` enum, which represents various errors that
+//! can occur during embedding generation or processing.
+//!
+//! Example usage:
+//!
+//! ```rust
+//! use rig::providers::openai::{Client, self};
+//! use rig::embeddings::{EmbeddingModel, EmbeddingsBuilder};
+//!
+//! // Initialize the OpenAI client
+//! let openai = Client::new("your-openai-api-key");
+//!
+//! // Create an instance of the `text-embedding-ada-002` model
+//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002);
+//!
+//! // Create an embeddings builder and add documents
+//! let embeddings = EmbeddingsBuilder::new(embedding_model)
+//!     .simple_document("doc1", "This is the first document.")
+//!     .simple_document("doc2", "This is the second document.")
+//!     .build()
+//!     .await
+//!     .expect("Failed to build embeddings.");
+//!
+//! // Use the generated embeddings
+//! // ...
+//! ```
 
 use std::{cmp::max, collections::HashMap};
 
 use futures::{stream, StreamExt, TryStreamExt};
@@ -28,6 +69,7 @@ pub enum EmbeddingError {
     ProviderError(String),
 }
 
+/// Trait for embedding models that can generate embeddings for documents.
 pub trait EmbeddingModel: Clone + Sync + Send {
     const MAX_DOCUMENTS: usize;
 
@@ -56,9 +98,12 @@ pub trait EmbeddingModel: Clone + Sync + Send {
     ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;
 }
 
+/// Struct that holds a single document and its embedding.
 #[derive(Clone, Default, Deserialize, Serialize)]
 pub struct Embedding {
+    /// The document that was embedded
     pub document: String,
+    /// The embedding vector
     pub vec: Vec<f64>,
 }
 
@@ -85,7 +130,12 @@ impl Embedding {
     }
 }
 
-/// Struct to store the document and its embeddings.
+/// Struct that holds a document and its embeddings.
+///
+/// The struct is designed to model any kind of documents that can be serialized to JSON
+/// (including a simple string). Moreover, it can hold multiple embeddings for the same
+/// document, thus allowing a large or non-text document to be "ragged" from various
+/// smaller text documents.
 #[derive(Clone, Eq, PartialEq, Serialize, Deserialize)]
 pub struct DocumentEmbeddings {
     #[serde(rename = "_id")]
@@ -22,8 +22,9 @@ pub enum ExtractionError {
     PromptError(#[from] PromptError),
 }
 
+/// Extractor for structured data from text
 pub struct Extractor<M: CompletionModel, T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> {
-    pub agent: Agent<M>,
+    agent: Agent<M>,
     _t: PhantomData<T>,
 }
 
@@ -42,6 +43,7 @@ where
     }
 }
 
+/// Builder for the Extractor
 pub struct ExtractorBuilder<
     T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static,
     M: CompletionModel,
@@ -50,7 +52,7 @@ pub struct ExtractorBuilder<
     _t: PhantomData<T>,
 }
 
-impl<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, M: CompletionModel>
+impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync, M: CompletionModel>
     ExtractorBuilder<T, M>
 {
     pub fn new(model: M) -> Self {
@@ -58,7 +60,7 @@ impl<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, M: CompletionModel>
         agent_builder: AgentBuilder::new(model)
             .preamble("\
                 You are an AI assistant whose purpose is to extract structured data from the provided text.\n\
-                You will have access to a `submit` function that defines the structure of the data to extract from the provided text.\n
+                You will have access to a `submit` function that defines the structure of the data to extract from the provided text.\n\
                 Use the `submit` function to submit the structured data.\n\
                 Be sure to fill out every field and ALWAYS CALL THE `submit` function, even with default values!
             ")
@@ -81,6 +83,7 @@ impl<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, M: CompletionModel>
         self
     }
 
+    /// Build the Extractor
     pub fn build(self) -> Extractor<M, T> {
         Extractor {
             agent: self.agent_builder.build(),
@@ -96,13 +99,13 @@ struct SubmitTool<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> {
 
 #[derive(Debug, thiserror::Error)]
 #[error("SubmitError")]
-pub struct SubmitError;
+struct SubmitError;
 
-impl<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> Tool for SubmitTool<T> {
+impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync> Tool for SubmitTool<T> {
     const NAME: &'static str = "submit";
 
     type Error = SubmitError;
-    type Args = String;
-    type Output = String;
+    type Args = T;
+    type Output = T;
 
     async fn definition(&self, _prompt: String) -> ToolDefinition {
         ToolDefinition {
@@ -113,7 +116,7 @@ impl<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> Tool for SubmitTool<
         }
     }
 
-    async fn call(&self, data: Self::Args) -> Result<String, Self::Error> {
+    async fn call(&self, data: Self::Args) -> Result<Self::Output, Self::Error> {
         Ok(data)
     }
 }
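Effect of the `SubmitTool` change above, sketched: with `type Args = T`, the toolset's JSON-argument deserialization lands directly on the target type, so `call` becomes the identity function and the extractor no longer re-parses a `String`. The `Person` type below is illustrative:

```rust
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

// Illustrative target type for the extractor.
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
struct Person {
    name: String,
    age: u8,
}

fn main() {
    // This is the step the toolset performs for `SubmitTool<Person>` now
    // that `Args = T`: the model's JSON arguments deserialize straight
    // into the target type, with no intermediate `String` round-trip.
    let args = r#"{"name": "John Doe", "age": 42}"#;
    let person: Person = serde_json::from_str(args).expect("malformed submit args");
    println!("{person:?}");
}
```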
@@ -1,3 +1,27 @@
+//!
+//!
+//!
+//! Simple example:
+//! ```
+//! use rig::{completion::Prompt, providers::openai};
+//!
+//! #[tokio::main]
+//! async fn main() {
+//!     // Create OpenAI client and model
+//!     let openai_client = openai::Client::from_env();
+//!
+//!     let gpt4 = openai_client.model("gpt-4").build();
+//!
+//!     // Prompt the model and print its response
+//!     let response = gpt4
+//!         .prompt("Who are you?")
+//!         .await
+//!         .expect("Failed to prompt GPT-4");
+//!
+//!     println!("GPT-4: {response}");
+//! }
+//! ```
+
 pub mod agent;
 pub mod cli_chatbot;
 pub mod completion;
@@ -1,12 +1,39 @@
+//! This module contains the implementation of the `Model` struct and its builder.
+//!
+//! The `Model` struct represents a bare LLM model that can be used to interact with the LLM
+//! by providing prompts and chat history.
+//!
+//! The `ModelBuilder` struct provides a builder pattern for creating instances of the `Model` struct.
+//! It allows configuring the temperature before building the model.
+//!
+//! # Example
+//! ```rust
+//! use rig::{completion::Prompt, providers::openai};
+//!
+//! let openai_client = openai::Client::from_env();
+//!
+//! // Configure the model
+//! let model = openai_client.model("gpt-4o")
+//!     .temperature(0.8)
+//!     .build();
+//!
+//! // Use the model for completions and prompts
+//! let completion_req_builder = model.completion("Prompt", chat_history).await;
+//! let chat_response = model.chat("Prompt", chat_history).await;
+//! ```
+//!
+//! For more information on how to use the `Model` struct and its builder, refer to the documentation of the respective structs and methods.
 use crate::completion::{
-    Completion, CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse,
-    Message, ModelChoice, Prompt, PromptError,
+    Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder,
+    CompletionResponse, Message, ModelChoice, Prompt, PromptError,
 };
 
+/// A model that can be used to prompt completions from a completion model.
+/// This is the simplest building block for creating an LLM powered application.
 pub struct Model<M: CompletionModel> {
-    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
+    /// Completion model (e.g.: OpenAI's `gpt-3.5-turbo-1106`, Cohere's `command-r`)
     model: M,
     /// Temperature of the model
     temperature: Option<f64>,
@@ -27,6 +54,12 @@ impl<M: CompletionModel> Completion<M> for Model<M> {
 }
 
-impl<M: CompletionModel> Prompt for Model<M> {
+impl<M: CompletionModel> Prompt for Model<M> {
+    async fn prompt(&self, prompt: &str) -> Result<String, PromptError> {
+        self.chat(prompt, vec![]).await
+    }
+}
+
+impl<M: CompletionModel> Chat for Model<M> {
     async fn chat(&self, prompt: &str, chat_history: Vec<Message>) -> Result<String, PromptError> {
         match self.completion(prompt, chat_history).await?.send().await? {
             CompletionResponse {
@@ -43,12 +76,28 @@ impl<M: CompletionModel> Prompt for Model<M> {
     }
 }
 
+/// A builder for creating a model
+///
+/// # Example
+/// ```
+/// use rig::{providers::openai, model::ModelBuilder};
+///
+/// let openai_client = openai::Client::from_env();
+///
+/// let gpt4 = openai_client.completion_model("gpt-4");
+///
+/// // Configure the model
+/// let model = ModelBuilder::new(gpt4)
+///     .temperature(0.8)
+///     .build();
+/// ```
 pub struct ModelBuilder<M: CompletionModel> {
     model: M,
-    pub temperature: Option<f64>,
+    temperature: Option<f64>,
 }
 
 impl<M: CompletionModel> ModelBuilder<M> {
+    /// Create a new model builder
     pub fn new(model: M) -> Self {
         Self {
             model,
@@ -56,11 +105,19 @@ impl<M: CompletionModel> ModelBuilder<M> {
         }
     }
 
+    /// Set the temperature of the model
     pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
         self
     }
 
+    /// Set the temperature of the model (set to None to use the default temperature of the model)
     pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
         self.temperature = temperature;
         self
     }
 
+    /// Build the model
     pub fn build(self) -> Model<M> {
         Model {
             model: self.model,
@@ -1,3 +1,10 @@
+//! Cohere API client and Rig integration
+//!
+//! # Example
+//! ```
+//! use rig::{providers::cohere, model::ModelBuilder};
+//!
+//! ```
 use std::collections::HashMap;
 
 use crate::{
@@ -75,7 +82,7 @@ impl Client {
         AgentBuilder::new(self.completion_model(model))
     }
 
-    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync>(
+    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
         &self,
         model: &str,
     ) -> ExtractorBuilder<T, CompletionModel> {
@@ -221,9 +228,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
                 })
                 .collect())
             }
-            ApiResponse::Err(error) => {
-                return Err(EmbeddingError::ProviderError(error.message));
-            }
+            ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)),
         }
     }
 }
@@ -467,7 +472,7 @@ impl CompletionModel {
 }
 
 impl completion::CompletionModel for CompletionModel {
-    type T = CompletionResponse;
+    type Response = CompletionResponse;
 
     async fn completion(
         &self,
@@ -1,3 +1,4 @@
+//! OpenAI API client and Rig integration
 use crate::{
     agent::AgentBuilder,
     completion::{self, CompletionError, CompletionRequest},
@@ -24,10 +25,12 @@ pub struct Client {
 }
 
 impl Client {
+    /// Create a new OpenAI client with the given API key.
     pub fn new(api_key: &str) -> Self {
         Self::from_url(api_key, OPENAI_API_BASE_URL)
     }
 
+    /// Create a new OpenAI client with the given API key and base API URL.
     pub fn from_url(api_key: &str, base_url: &str) -> Self {
         Self {
             base_url: base_url.to_string(),
@@ -47,32 +50,116 @@ impl Client {
         }
     }
 
-    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
+    /// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable.
+    /// Panics if the environment variable is not set.
+    pub fn from_env() -> Self {
+        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
+        Self::new(&api_key)
+    }
+
+    fn post(&self, path: &str) -> reqwest::RequestBuilder {
         let url = format!("{}/{}", self.base_url, path).replace("//", "/");
         self.http_client.post(url)
     }
 
+    /// Create an embedding model with the given name.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::openai::{Client, self};
+    ///
+    /// // Initialize the OpenAI client
+    /// let openai = Client::new("your-open-ai-api-key");
+    ///
+    /// let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_3_LARGE);
+    /// ```
     pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
         EmbeddingModel::new(self.clone(), model)
     }
 
+    /// Create an embedding builder with the given embedding model.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::openai::{Client, self};
+    ///
+    /// // Initialize the OpenAI client
+    /// let openai = Client::new("your-open-ai-api-key");
+    ///
+    /// let embeddings = openai.embeddings(openai::TEXT_EMBEDDING_3_LARGE)
+    ///     .simple_document("doc0", "Hello, world!")
+    ///     .simple_document("doc1", "Goodbye, world!")
+    ///     .build()
+    ///     .await
+    ///     .expect("Failed to embed documents");
+    /// ```
     pub fn embeddings(&self, model: &str) -> embeddings::EmbeddingsBuilder<EmbeddingModel> {
         embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
     }
 
+    /// Create a completion model with the given name.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::openai::{Client, self};
+    ///
+    /// // Initialize the OpenAI client
+    /// let openai = Client::new("your-open-ai-api-key");
+    ///
+    /// let gpt4 = openai.completion_model(openai::GPT_4);
+    /// ```
     pub fn completion_model(&self, model: &str) -> CompletionModel {
         CompletionModel::new(self.clone(), model)
     }
 
+    /// Create a model builder with the given completion model.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::openai::{Client, self};
+    ///
+    /// // Initialize the OpenAI client
+    /// let openai = Client::new("your-open-ai-api-key");
+    ///
+    /// let completion_model = openai.model(openai::GPT_4)
+    ///     .temperature(0.0)
+    ///     .build();
+    /// ```
     pub fn model(&self, model: &str) -> ModelBuilder<CompletionModel> {
         ModelBuilder::new(self.completion_model(model))
     }
 
+    /// Create an agent builder with the given completion model.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::openai::{Client, self};
+    ///
+    /// // Initialize the OpenAI client
+    /// let openai = Client::new("your-open-ai-api-key");
+    ///
+    /// let agent = openai.agent(openai::GPT_4)
+    ///     .preamble("You are comedian AI with a mission to make people laugh.")
+    ///     .temperature(0.0)
+    ///     .build();
+    /// ```
     pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
         AgentBuilder::new(self.completion_model(model))
     }
 
-    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync>(
+    /// Create an extractor builder with the given completion model.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::openai::{Client, self};
+    ///
+    /// // Initialize the OpenAI client
+    /// let openai = Client::new("your-open-ai-api-key");
+    ///
+    /// let extractor = openai.extractor::<MyStruct>(openai::GPT_4)
+    ///     .build();
+    /// ```
+    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
         &self,
         model: &str,
     ) -> ExtractorBuilder<T, CompletionModel> {
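The new `from_env` constructor slots in ahead of all the documented entry points. A combined sketch (the model names are the ones used elsewhere in this diff):

```rust
use rig::providers::openai;

fn main() {
    // Reads OPENAI_API_KEY and panics if it is unset.
    let client = openai::Client::from_env();

    // The documented entry points, side by side:
    let _model = client.model("gpt-4").temperature(0.0).build();
    let _agent = client
        .agent("gpt-4")
        .preamble("You are a helpful assistant.")
        .build();
    let _embedding_model = client.embedding_model("text-embedding-ada-002");
}
```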
@@ -374,7 +461,7 @@ impl CompletionModel {
 }
 
 impl completion::CompletionModel for CompletionModel {
-    type T = CompletionResponse;
+    type Response = CompletionResponse;
 
     async fn completion(
         &self,
@@ -4,8 +4,8 @@ use futures::{stream, StreamExt, TryStreamExt};
 
 use crate::{
     completion::{
-        Completion, CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse,
-        Document, Message, ModelChoice, Prompt, PromptError,
+        Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder,
+        CompletionResponse, Document, Message, ModelChoice, Prompt, PromptError,
     },
     tool::{Tool, ToolSet, ToolSetError},
     vector_store::{NoIndex, VectorStoreError, VectorStoreIndex},
@@ -126,6 +126,12 @@ impl<M: CompletionModel, C: VectorStoreIndex, T: VectorStoreIndex> Completion<M>
 }
 
-impl<M: CompletionModel, C: VectorStoreIndex, T: VectorStoreIndex> Prompt for RagAgent<M, C, T> {
+impl<M: CompletionModel, C: VectorStoreIndex, T: VectorStoreIndex> Prompt for RagAgent<M, C, T> {
+    async fn prompt(&self, prompt: &str) -> Result<String, PromptError> {
+        self.chat(prompt, vec![]).await
+    }
+}
+
+impl<M: CompletionModel, C: VectorStoreIndex, T: VectorStoreIndex> Chat for RagAgent<M, C, T> {
     async fn chat(&self, prompt: &str, chat_history: Vec<Message>) -> Result<String, PromptError> {
         match self.completion(prompt, chat_history).await?.send().await? {
             CompletionResponse {
@@ -146,6 +152,28 @@ impl<M: CompletionModel, C: VectorStoreIndex, T: VectorStoreIndex> RagAgent<M, C
     }
 }
 
+/// Builder for creating a RAG agent
+///
+/// # Example
+/// ```
+/// use rig::{providers::openai, rag_agent::RagAgentBuilder};
+///
+/// let openai_client = openai::Client::from_env();
+///
+/// let model = openai_client.completion_model("gpt-4");
+///
+/// // Configure the agent
+/// let agent = RagAgentBuilder::new(model)
+///     .preamble("System prompt")
+///     .static_context("Context document 1")
+///     .static_context("Context document 2")
+///     .dynamic_context(2, vector_index)
+///     .tool(tool1)
+///     .tool(tool2)
+///     .temperature(0.8)
+///     .additional_params(json!({"foo": "bar"}))
+///     .build();
+/// ```
 pub struct RagAgentBuilder<M: CompletionModel, C: VectorStoreIndex, T: VectorStoreIndex> {
     /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
     model: M,
|
@ -16,12 +16,66 @@ pub enum ToolError {
|
|||
}
|
||||
|
||||
/// Trait that represents a simple LLM tool
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use rig::tool::{ToolSet, Tool};
|
||||
///
|
||||
/// #[derive(Deserialize)]
|
||||
/// struct AddArgs {
|
||||
/// x: i32,
|
||||
/// y: i32,
|
||||
/// }
|
||||
///
|
||||
/// #[derive(Debug, thiserror::Error)]
|
||||
/// #[error("Math error")]
|
||||
/// struct MathError;
|
||||
///
|
||||
/// #[derive(Deserialize, Serialize)]
|
||||
/// struct Adder;
|
||||
///
|
||||
/// impl Tool for Adder {
|
||||
/// const NAME: &'static str = "add";
|
||||
///
|
||||
/// type Error = MathError;
|
||||
/// type Args = AddArgs;
|
||||
/// type Output = i32;
|
||||
///
|
||||
/// async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||
/// ToolDefinition {
|
||||
/// name: "add".to_string(),
|
||||
/// description: "Add x and y together".to_string(),
|
||||
/// parameters: json!({
|
||||
/// "type": "object",
|
||||
/// "properties": {
|
||||
/// "x": {
|
||||
/// "type": "number",
|
||||
/// "description": "The first number to add"
|
||||
/// },
|
||||
/// "y": {
|
||||
/// "type": "number",
|
||||
/// "description": "The second number to add"
|
||||
/// }
|
||||
/// }
|
||||
/// })
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||
/// let result = args.x + args.y;
|
||||
/// Ok(result)
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
pub trait Tool: Sized + Send + Sync {
|
||||
/// The name of the tool. This name should be unique.
|
||||
const NAME: &'static str;
|
||||
|
||||
/// The error type of the tool.
|
||||
type Error: std::error::Error + Send + Sync + 'static;
|
||||
/// The arguments type of the tool.
|
||||
type Args: for<'a> Deserialize<'a> + Send + Sync;
|
||||
/// The output type of the tool.
|
||||
type Output: Serialize;
|
||||
|
||||
/// A method returning the name of the tool.
|
||||
|
@@ -180,7 +234,8 @@ pub struct ToolSet {
 }
 
 impl ToolSet {
-    pub fn new(tools: Vec<impl ToolDyn + 'static>) -> Self {
+    /// Create a new ToolSet from a list of tools
+    pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
         let mut toolset = Self::default();
         tools.into_iter().for_each(|tool| {
             toolset.add_tool(tool);
@@ -188,19 +243,23 @@ impl ToolSet {
         toolset
     }
 
+    /// Create a toolset builder
     pub fn builder() -> ToolSetBuilder {
         ToolSetBuilder::default()
     }
 
+    /// Check if the toolset contains a tool with the given name
     pub fn contains(&self, toolname: &str) -> bool {
         self.tools.contains_key(toolname)
     }
 
+    /// Add a tool to the toolset
     pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
         self.tools
             .insert(tool.name(), ToolType::Simple(Box::new(tool)));
     }
 
+    /// Merge another toolset into this one
     pub fn add_tools(&mut self, toolset: ToolSet) {
         self.tools.extend(toolset.tools);
     }
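A usage sketch of the renamed constructor plus the documented incremental API. `Adder` is an assumption here: it stands in for any `Tool` implementation, e.g. the calculator example's tool earlier in this diff:

```rust
use rig::tool::{ToolSet, ToolSetError};

// `Adder` is assumed to implement `Tool` with NAME = "add", as in the
// calculator example above.
async fn toolset_demo(adder: Adder) -> Result<(), ToolSetError> {
    // `ToolSet::new` is now `ToolSet::from_tools`.
    let mut tools = ToolSet::from_tools(vec![adder]);
    assert!(tools.contains("add"));

    // `call` dispatches by tool name with JSON-encoded arguments.
    let sum = tools.call("add", r#"{"x": 2, "y": 3}"#.to_string()).await?;
    println!("add returned {sum}");
    Ok(())
}
```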
@@ -209,9 +268,14 @@ impl ToolSet {
         self.tools.get(toolname)
     }
 
+    /// Call a tool with the given name and arguments
+    ///
+    /// # Example
+    /// ```
+    /// ```
     pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
         if let Some(tool) = self.tools.get(toolname) {
-            tracing::info!(target: "ai",
+            tracing::info!(target: "rig",
                 "Calling tool {toolname} with args:\n{}",
                 serde_json::to_string_pretty(&args).unwrap_or_else(|_| args.clone())
             );
@@ -221,6 +285,7 @@ impl ToolSet {
         }
     }
 
+    /// Get the documents of all the tools in the toolset
     pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
         let mut docs = Vec::new();
         for tool in self.tools.values() {
@@ -1,3 +1,4 @@
+//! In-memory implementation of a vector store.
 use std::{
     cmp::Reverse,
     collections::{BinaryHeap, HashMap},
@@ -9,6 +10,8 @@ use serde::{Deserialize, Serialize};
 use super::{VectorStore, VectorStoreError, VectorStoreIndex};
 use crate::embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel, EmbeddingsBuilder};
 
+/// InMemoryVectorStore is a simple in-memory vector store that stores embeddings
+/// in-memory using a HashMap.
 #[derive(Clone, Default, Deserialize, Serialize)]
 pub struct InMemoryVectorStore {
     /// The embeddings are stored in a HashMap with the document ID as the key.
@@ -159,10 +162,6 @@ impl<M: EmbeddingModel> InMemoryVectorIndex<M> {
 }
 
 impl<M: EmbeddingModel + std::marker::Sync> VectorStoreIndex for InMemoryVectorIndex<M> {
-    async fn embed_document(&self, document: &str) -> Result<Embedding, VectorStoreError> {
-        Ok(self.model.embed_document(document).await?)
-    }
-
     async fn top_n_from_query(
         &self,
         query: &str,
@@ -208,7 +207,7 @@ impl<M: EmbeddingModel + std::marker::Sync> VectorStoreIndex for InMemoryVectorI
         }
 
         // Log selected tools with their distances
-        tracing::info!(target: "ai",
+        tracing::info!(target: "rig",
             "Selected documents: {}",
             docs.iter()
                 .map(|Reverse(RankingItem(distance, id, _, _))| format!("{} ({})", id, distance))
@@ -17,36 +17,38 @@ pub enum VectorStoreError {
     DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync>),
 }
 
+/// Trait for vector stores
 pub trait VectorStore: Send + Sync {
+    /// Query type for the vector store
     type Q;
 
+    /// Add a list of documents to the vector store
     fn add_documents(
         &mut self,
         documents: Vec<DocumentEmbeddings>,
     ) -> impl std::future::Future<Output = Result<(), VectorStoreError>> + Send;
 
+    /// Get the embeddings of a document by its id
     fn get_document_embeddings(
         &self,
         id: &str,
     ) -> impl std::future::Future<Output = Result<Option<DocumentEmbeddings>, VectorStoreError>> + Send;
 
+    /// Get the document by its id and deserialize it into the given type
     fn get_document<T: for<'a> Deserialize<'a>>(
         &self,
         id: &str,
     ) -> impl std::future::Future<Output = Result<Option<T>, VectorStoreError>> + Send;
 
+    /// Get the document by a query and deserialize it into the given type
     fn get_document_by_query(
         &self,
         query: Self::Q,
     ) -> impl std::future::Future<Output = Result<Option<DocumentEmbeddings>, VectorStoreError>> + Send;
 }
 
+/// Trait for vector store indexes
 pub trait VectorStoreIndex: Send + Sync {
-    fn embed_document(
-        &self,
-        document: &str,
-    ) -> impl std::future::Future<Output = Result<Embedding, VectorStoreError>> + Send;
-
     /// Get the top n documents based on the distance to the given embedding.
     /// The distance is calculated as the cosine distance between the prompt and
     /// the document embedding.
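Putting the documented `VectorStore` trait together with the embeddings builder from earlier in this diff — a sketch. The document ids and contents are illustrative, and it is assumed that `build()` yields the `Vec<DocumentEmbeddings>` that `add_documents` expects and that both error types convert into `Box<dyn Error>`:

```rust
use rig::embeddings::{EmbeddingModel, EmbeddingsBuilder};
use rig::vector_store::{in_memory_store::InMemoryVectorStore, VectorStore};

async fn load_store(
    embedding_model: impl EmbeddingModel,
) -> Result<InMemoryVectorStore, Box<dyn std::error::Error>> {
    // Embed a couple of documents...
    let embeddings = EmbeddingsBuilder::new(embedding_model)
        .simple_document("doc0", "Definition of *flumbrel*: a small swarm of pink birds.")
        .simple_document("doc1", "Definition of *glarb-glarb*: an ancient tool.")
        .build()
        .await?;

    // ...and load them through the `VectorStore` trait.
    let mut store = InMemoryVectorStore::default();
    store.add_documents(embeddings).await?;
    Ok(store)
}
```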
@@ -135,10 +137,6 @@ pub trait VectorStoreIndex: Send + Sync {
 pub struct NoIndex;
 
 impl VectorStoreIndex for NoIndex {
-    async fn embed_document(&self, _document: &str) -> Result<Embedding, VectorStoreError> {
-        Ok(Embedding::default())
-    }
-
     async fn top_n_from_query(
         &self,
         _query: &str,
|
@ -114,10 +114,6 @@ impl<M: EmbeddingModel> MongoDbVectorIndex<M> {
|
|||
}
|
||||
|
||||
impl<M: EmbeddingModel + std::marker::Sync + Send> VectorStoreIndex for MongoDbVectorIndex<M> {
|
||||
async fn embed_document(&self, document: &str) -> Result<Embedding, VectorStoreError> {
|
||||
Ok(self.model.embed_document(document).await?)
|
||||
}
|
||||
|
||||
async fn top_n_from_query(
|
||||
&self,
|
||||
query: &str,
|
||||
|
@@ -166,7 +162,7 @@ impl<M: EmbeddingModel + std::marker::Sync + Send> VectorStoreIndex for MongoDbV
             results.push((score, document));
         }
 
-        tracing::info!(target: "ai",
+        tracing::info!(target: "rig",
             "Selected documents: {}",
             results.iter()
                 .map(|(distance, doc)| format!("{} ({})", doc.id, distance))