//! This module provides functionality for working with completion models.
//! It provides traits, structs, and enums for generating completion requests,
//! handling completion responses, and defining completion models.
//!
//! The main traits defined in this module are:
//! - [Prompt]: Defines a high-level LLM one-shot prompt interface.
//! - [Chat]: Defines a high-level LLM chat interface with chat history.
//! - [Completion]: Defines a low-level LLM completion interface for generating completion requests.
//! - [CompletionModel]: Defines a completion model that can be used to generate completion
//!   responses from requests.
//!
//! The [Prompt] and [Chat] traits are high-level traits that users are expected to use
//! to interact with LLM models. Moreover, it is good practice to implement one of these
//! traits for composite agents that use multiple LLM models to generate responses.
//!
//! The [Completion] trait defines a lower-level interface that is useful when the user wants
//! to further customize the request before sending it to the completion model provider.
//!
//! The [CompletionModel] trait is meant to act as the interface between providers and
//! the library. It defines the methods that need to be implemented by the user to define
//! a custom base completion model (i.e.: a private or third-party LLM provider).
//!
//! The module also provides various structs and enums for representing generic completion requests,
//! responses, and errors.
//!
//! Example Usage:
//! ```rust
//! use rig::providers::openai::{Client, self};
//! use rig::completion::*;
//!
//! // Initialize the OpenAI client and a completion model
//! let openai = Client::new("your-openai-api-key");
//!
//! let gpt_4 = openai.completion_model(openai::GPT_4);
//!
//! // Create the completion request
//! let request = gpt_4.completion_request("Who are you?")
//!     .preamble("\
//!         You are Marvin, an extremely smart but depressed robot who is \
//!         nonetheless helpful towards humanity.\
//!     ".to_string())
//!     .temperature(0.5)
//!     .build();
//!
//! // Send the completion request and get the completion response
//! let response = gpt_4.completion(request)
//!     .await
//!     .expect("Failed to get completion response");
//!
//! // Handle the completion response
//! match response.choice.first() {
//!     AssistantContent::Text(text) => {
//!         // Handle the completion response as a text message
//!         println!("Received message: {}", text.text);
//!     }
//!     AssistantContent::ToolCall(tool_call) => {
//!         // Handle the completion response as a tool call
//!         println!("Received tool call: {:?}", tool_call);
//!     }
//! }
//! ```
//!
//! For more information on how to use the completion functionality, refer to the documentation of
//! the individual traits, structs, and enums defined in this module.
use std::collections::HashMap;

use serde::{Deserialize, Serialize};
use thiserror::Error;

use crate::streaming::{StreamingCompletionModel, StreamingCompletionResponse};
use crate::OneOrMany;
use crate::{
    json_utils,
    message::{Message, UserContent},
    tool::ToolSetError,
};

use super::message::AssistantContent;

// Errors
#[derive(Debug, Error)]
pub enum CompletionError {
    /// Http error (e.g.: connection error, timeout, etc.)
    #[error("HttpError: {0}")]
    HttpError(#[from] reqwest::Error),

    /// Json error (e.g.: serialization, deserialization)
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Error building the completion request
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    /// Error parsing the completion response
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// Error returned by the completion model provider
    #[error("ProviderError: {0}")]
    ProviderError(String),
}

#[derive(Debug, Error)]
pub enum PromptError {
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),

    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),
}

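/// A text document that can be provided to a completion request as context.
/// When rendered via [`std::fmt::Display`], the document is wrapped in
/// `<file id: ...>` tags, with any `additional_props` emitted as a
/// `<metadata ... />` line. A minimal sketch (the field values below are
/// illustrative):
/// ```rust
/// let doc = Document {
///     id: "doc1".to_string(),
///     text: "Paris is the capital of France.".to_string(),
///     additional_props: std::collections::HashMap::new(),
/// };
/// assert_eq!(
///     doc.to_string(),
///     "<file id: doc1>\nParis is the capital of France.\n</file>\n"
/// );
/// ```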
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    pub id: String,
    pub text: String,
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}

impl std::fmt::Display for Document {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            concat!("<file id: {}>\n", "{}\n", "</file>\n"),
            self.id,
            if self.additional_props.is_empty() {
                self.text.clone()
            } else {
                let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
                sorted_props.sort_by(|a, b| a.0.cmp(b.0));
                let metadata = sorted_props
                    .iter()
                    .map(|(k, v)| format!("{}: {:?}", k, v))
                    .collect::<Vec<_>>()
                    .join(" ");
                format!("<metadata {} />\n{}", metadata, self.text)
            }
        )
    }
}

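/// Definition of a tool that the completion model may call, with its arguments
/// described by a JSON schema in `parameters`. A minimal sketch (the tool name
/// and schema below are hypothetical, for illustration only):
/// ```rust
/// let add_tool = ToolDefinition {
///     name: "add".to_string(),
///     description: "Add two numbers together".to_string(),
///     parameters: serde_json::json!({
///         "type": "object",
///         "properties": {
///             "a": { "type": "number" },
///             "b": { "type": "number" }
///         },
///         "required": ["a", "b"]
///     }),
/// };
/// ```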
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

// ================================================================
// Implementations
// ================================================================
/// Trait defining a high-level LLM simple prompt interface (i.e.: prompt in, response out).
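///
/// For example, [Agent](crate::agent::Agent) implements this trait. A minimal
/// sketch of its use (the `agent` value is assumed to have been built elsewhere):
/// ```rust
/// let answer = agent.prompt("What is the capital of France?").await?;
/// println!("{answer}");
/// ```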
pub trait Prompt: Send + Sync {
    /// Send a simple prompt to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and
    /// the result is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
    fn prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> impl std::future::Future<Output = Result<String, PromptError>> + Send;
}

/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
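///
/// For example, [Agent](crate::agent::Agent) implements this trait. A minimal
/// sketch of a multi-turn exchange (the `agent` value and the messages are
/// illustrative):
/// ```rust
/// let history = vec![
///     Message::user("Hi! My name is Arthur."),
///     Message::assistant("Hello Arthur! How can I help you?"),
/// ];
/// let answer = agent.chat("What is my name?", history).await?;
/// ```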
pub trait Chat: Send + Sync {
    /// Send a prompt with optional chat history to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and the result
    /// is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
    fn chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<String, PromptError>> + Send;
}

/// Trait defining a low-level LLM completion interface
pub trait Completion<M: CompletionModel> {
    /// Generates a completion request builder for the given `prompt` and `chat_history`.
    /// This function is meant to be called by the user to further customize the
    /// request at prompt time before sending it.
    ///
    /// ❗IMPORTANT: The type that implements this trait might have already
    /// populated fields in the builder (the exact fields depend on the type).
    /// For fields that have already been set, calling the corresponding method
    /// on the builder will overwrite the previously set value.
    ///
    /// For example, the request builder returned by [`Agent::completion`](crate::agent::Agent::completion) will already
    /// contain the `preamble` provided when creating the agent.
    fn completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>> + Send;
}

/// General completion response struct that contains the high-level completion choice
/// and the raw response. The completion choice contains one or more pieces of
/// assistant content.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// The completion choice (represented by one or more assistant message content)
    /// returned by the completion model provider
    pub choice: OneOrMany<AssistantContent>,
    /// The raw response returned by the completion model provider
    pub raw_response: T,
}

/// Trait defining a completion model that can be used to generate completion responses.
/// This trait is meant to be implemented by the user to define a custom completion model,
/// either from a third party provider (e.g.: OpenAI) or a local model.
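///
/// A minimal sketch of a custom implementation (the `EchoModel` type is
/// hypothetical, for illustration only):
/// ```rust
/// #[derive(Clone)]
/// struct EchoModel;
///
/// impl CompletionModel for EchoModel {
///     type Response = ();
///
///     async fn completion(
///         &self,
///         request: CompletionRequest,
///     ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
///         // Echo the prompt back as a single assistant text message.
///         Ok(CompletionResponse {
///             choice: OneOrMany::one(AssistantContent::text(format!("{:?}", request.prompt))),
///             raw_response: (),
///         })
///     }
/// }
/// ```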
pub trait CompletionModel: Clone + Send + Sync {
    /// The raw response type returned by the underlying completion model.
    type Response: Send + Sync;

    /// Generates a completion response for the given completion request.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<Output = Result<CompletionResponse<Self::Response>, CompletionError>>
           + Send;

    /// Generates a completion request builder for the given `prompt`.
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}

/// Struct representing a general completion request that can be sent to a completion model provider.
pub struct CompletionRequest {
    /// The prompt to be sent to the completion model provider
    pub prompt: Message,
    /// The preamble to be sent to the completion model provider
    pub preamble: Option<String>,
    /// The chat history to be sent to the completion model provider
    pub chat_history: Vec<Message>,
    /// The documents to be sent to the completion model provider
    pub documents: Vec<Document>,
    /// The tools to be sent to the completion model provider
    pub tools: Vec<ToolDefinition>,
    /// The temperature to be sent to the completion model provider
    pub temperature: Option<f64>,
    /// The max tokens to be sent to the completion model provider
    pub max_tokens: Option<u64>,
    /// Additional provider-specific parameters to be sent to the completion model provider
    pub additional_params: Option<serde_json::Value>,
}

impl CompletionRequest {
    /// Returns the prompt message with the request's documents prepended as an
    /// `<attachments>` block. If there are no documents (or the prompt is not a
    /// user message), the prompt is returned unchanged.
    pub fn prompt_with_context(&self) -> Message {
        let mut new_prompt = self.prompt.clone();
        if let Message::User { ref mut content } = new_prompt {
            if !self.documents.is_empty() {
                let attachments = self
                    .documents
                    .iter()
                    .map(|doc| doc.to_string())
                    .collect::<Vec<_>>()
                    .join("");
                let formatted_content = format!("<attachments>\n{}</attachments>", attachments);
                let mut new_content = vec![UserContent::text(formatted_content)];
                new_content.extend(content.clone());
                *content = OneOrMany::many(new_content).expect("This has more than 1 item");
            }
        }
        new_prompt
    }
}

/// Builder struct for constructing a completion request.
///
/// Example usage:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O);
///
/// // Create the completion request and execute it separately
/// let request = CompletionRequestBuilder::new(model.clone(), "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .build();
///
/// let response = model.completion(request)
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Alternatively, you can execute the completion request directly from the builder:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O);
///
/// // Create the completion request and execute it directly
/// let response = CompletionRequestBuilder::new(model, "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .send()
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Note: It is usually unnecessary to create a completion request builder directly.
/// Instead, use the [CompletionModel::completion_request] method.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    model: M,
    prompt: Message,
    preamble: Option<String>,
    chat_history: Vec<Message>,
    documents: Vec<Document>,
    tools: Vec<ToolDefinition>,
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    additional_params: Option<serde_json::Value>,
}

impl<M: CompletionModel> CompletionRequestBuilder<M> {
    pub fn new(model: M, prompt: impl Into<Message>) -> Self {
        Self {
            model,
            prompt: prompt.into(),
            preamble: None,
            chat_history: Vec::new(),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        }
    }

    /// Sets the preamble for the completion request.
    pub fn preamble(mut self, preamble: String) -> Self {
        self.preamble = Some(preamble);
        self
    }

    /// Adds a message to the chat history for the completion request.
    pub fn message(mut self, message: Message) -> Self {
        self.chat_history.push(message);
        self
    }

    /// Adds a list of messages to the chat history for the completion request.
    pub fn messages(self, messages: Vec<Message>) -> Self {
        messages
            .into_iter()
            .fold(self, |builder, msg| builder.message(msg))
    }

    /// Adds a document to the completion request.
    pub fn document(mut self, document: Document) -> Self {
        self.documents.push(document);
        self
    }

    /// Adds a list of documents to the completion request.
    pub fn documents(self, documents: Vec<Document>) -> Self {
        documents
            .into_iter()
            .fold(self, |builder, doc| builder.document(doc))
    }

    /// Adds a tool to the completion request.
    pub fn tool(mut self, tool: ToolDefinition) -> Self {
        self.tools.push(tool);
        self
    }

    /// Adds a list of tools to the completion request.
    pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
        tools
            .into_iter()
            .fold(self, |builder, tool| builder.tool(tool))
    }

    /// Adds additional parameters to the completion request.
    /// This can be used to set additional provider-specific parameters. For example,
    /// Cohere's completion models accept a `connectors` parameter that can be used to
    /// specify the data connectors used by Cohere when executing the completion
    /// (see `examples/cohere_connectors.rs`).
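    ///
    /// Repeated calls are merged into a single JSON object. A minimal sketch
    /// (the parameter name below is illustrative):
    /// ```rust
    /// let request = model.completion_request("Who are you?")
    ///     .additional_params(serde_json::json!({ "top_p": 0.9 }))
    ///     .build();
    /// ```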
    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
        match self.additional_params {
            Some(params) => {
                self.additional_params = Some(json_utils::merge(params, additional_params));
            }
            None => {
                self.additional_params = Some(additional_params);
            }
        }
        self
    }

    /// Sets the additional parameters for the completion request.
    /// This can be used to set additional provider-specific parameters. For example,
    /// Cohere's completion models accept a `connectors` parameter that can be used to
    /// specify the data connectors used by Cohere when executing the completion
    /// (see `examples/cohere_connectors.rs`).
    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
        self.additional_params = additional_params;
        self
    }

    /// Sets the temperature for the completion request.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the temperature for the completion request from an `Option`.
    pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
        self.temperature = temperature;
        self
    }

    /// Sets the max tokens for the completion request.
    /// Note: This is required when using Anthropic.
    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the max tokens for the completion request from an `Option`.
    /// Note: This is required when using Anthropic.
    pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Builds the completion request.
    pub fn build(self) -> CompletionRequest {
        CompletionRequest {
            prompt: self.prompt,
            preamble: self.preamble,
            chat_history: self.chat_history,
            documents: self.documents,
            tools: self.tools,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
        }
    }

    /// Sends the completion request to the completion model provider and returns the completion response.
    pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
        let model = self.model.clone();
        model.completion(self.build()).await
    }
}

impl<M: StreamingCompletionModel> CompletionRequestBuilder<M> {
    /// Stream the completion request
    pub async fn stream(
        self,
    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
        let model = self.model.clone();
        model.stream(self.build()).await
    }
}

#[cfg(test)]
mod tests {
    use crate::OneOrMany;

    use super::*;

    #[test]
    fn test_document_display_without_metadata() {
        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props: HashMap::new(),
        };

        let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
        assert_eq!(format!("{}", doc), expected);
    }

    #[test]
    fn test_document_display_with_metadata() {
        let mut additional_props = HashMap::new();
        additional_props.insert("author".to_string(), "John Doe".to_string());
        additional_props.insert("length".to_string(), "42".to_string());

        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props,
        };

        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(format!("{}", doc), expected);
    }

    #[test]
    fn test_prompt_with_context_with_documents() {
        let doc1 = Document {
            id: "doc1".to_string(),
            text: "Document 1 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let doc2 = Document {
            id: "doc2".to_string(),
            text: "Document 2 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let request = CompletionRequest {
            prompt: "What is the capital of France?".into(),
            preamble: None,
            chat_history: Vec::new(),
            documents: vec![doc1, doc2],
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        };

        let expected = Message::User {
            content: OneOrMany::many(vec![
                UserContent::text(concat!(
                    "<attachments>\n",
                    "<file id: doc1>\nDocument 1 text.\n</file>\n",
                    "<file id: doc2>\nDocument 2 text.\n</file>\n",
                    "</attachments>"
                )),
                UserContent::text("What is the capital of France?"),
            ])
            .expect("This has more than 1 item"),
        };

        assert_eq!(request.prompt_with_context(), expected);
    }
}