mirror of https://github.com/0xplaygrounds/rig
refactor: blow up agent and introduce typestate
This commit is contained in:
parent 4b38294264
commit 44aa8e9f13
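The refactor splits the old monolithic agent module into `agent`, `builder`, and `prompt_request` submodules, and encodes a prompt request's mode (single-shot vs. multi-turn) as a type parameter. A minimal, self-contained sketch of that typestate pattern, with illustrative names only (rig's real types appear in the diff below):

// Sketch of the typestate pattern this commit applies. The `State` parameter
// is purely compile-time: `multi_turn` consumes the `Simple` request and
// returns a `MultiTurn` one, so each state can expose different behaviour
// and invalid call orders fail to type-check.
use std::marker::PhantomData;

struct Simple;     // single-shot request
struct MultiTurn;  // tool-calling loop enabled

struct Request<State> {
    max_depth: usize,
    _state: PhantomData<State>,
}

impl Request<Simple> {
    fn new() -> Self {
        Request { max_depth: 0, _state: PhantomData }
    }

    // Consuming `self` moves the request into the `MultiTurn` state, so
    // `multi_turn` cannot be called twice on the same request.
    fn multi_turn(self, depth: usize) -> Request<MultiTurn> {
        Request { max_depth: depth, _state: PhantomData }
    }
}

fn main() {
    let req = Request::new().multi_turn(5);
    assert_eq!(req.max_depth, 5);
}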
@@ -0,0 +1,254 @@
use rig::{
    agent::Agent,
    completion::{CompletionError, CompletionModel, Prompt, PromptError, ToolDefinition},
    extractor::Extractor,
    message::Message,
    providers::anthropic,
    tool::Tool,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;

const CHAIN_OF_THOUGHT_PROMPT: &str = "
You are an assistant that extracts reasoning steps from a given prompt.
Do not return text, only return a tool call.
";

#[derive(Deserialize, Serialize, Debug, Clone, JsonSchema)]
struct ChainOfThoughtSteps {
    steps: Vec<String>,
}

/// A two-stage agent: an extractor first turns the user prompt into explicit
/// reasoning steps, then an executor agent works through those steps with tools.
struct ReasoningAgent<M: CompletionModel> {
    chain_of_thought_extractor: Extractor<M, ChainOfThoughtSteps>,
    executor: Agent<M>,
}

impl<M: CompletionModel> Prompt for ReasoningAgent<M> {
    #[allow(refining_impl_trait)]
    async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
        let prompt: Message = prompt.into();
        let mut chat_history = vec![prompt.clone()];

        // Stage 1: extract the chain-of-thought steps from the raw prompt.
        let extracted = self
            .chain_of_thought_extractor
            .extract(prompt)
            .await
            .map_err(|e| {
                tracing::error!("Extraction error: {:?}", e);
                CompletionError::ProviderError("".into())
            })?;

        if extracted.steps.is_empty() {
            return Ok("No reasoning steps provided.".into());
        }

        // Stage 2: feed the numbered steps to the executor agent.
        let mut reasoning_prompt = String::new();
        for (i, step) in extracted.steps.iter().enumerate() {
            reasoning_prompt.push_str(&format!("Step {}: {}\n", i + 1, step));
        }

        let response = self
            .executor
            .prompt(reasoning_prompt.as_str())
            .with_history(&mut chat_history)
            .multi_turn(20)
            .await?;

        Ok(response)
    }
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    tracing_subscriber::fmt()
        .with_max_level(tracing::Level::DEBUG)
        .with_target(false)
        .init();

    // Create Anthropic client
    let anthropic_client = anthropic::Client::from_env();

    let agent = ReasoningAgent {
        chain_of_thought_extractor: anthropic_client
            .extractor(anthropic::CLAUDE_3_5_SONNET)
            .preamble(CHAIN_OF_THOUGHT_PROMPT)
            .build(),

        executor: anthropic_client
            .agent(anthropic::CLAUDE_3_5_SONNET)
            .preamble(
                "You are an assistant here to help the user select which tool is most appropriate to perform arithmetic operations.
                Follow these instructions closely.
                1. Consider the user's request carefully and identify the core elements of the request.
                2. Select which tool among those made available to you is appropriate given the context.
                3. This is very important: never perform the operation yourself.
                4. When you think you've finished calling tools for the operation, present the final result from the series of tool calls you made.
                "
            )
            .tool(Add)
            .tool(Subtract)
            .tool(Multiply)
            .tool(Divide)
            .build(),
    };

    // Prompt the agent and print the response
    let result = agent
        .prompt("Calculate x for the equation: `20x + 23 = 400x / (1 - x)`")
        .await?;

    println!("\n\nReasoning Agent: {}", result);

    Ok(())
}

#[derive(Deserialize)]
struct OperationArgs {
    x: i32,
    y: i32,
}

#[derive(Debug, thiserror::Error)]
#[error("Math error")]
struct MathError;

#[derive(Deserialize, Serialize)]
struct Add;
impl Tool for Add {
    const NAME: &'static str = "add";

    type Error = MathError;
    type Args = OperationArgs;
    type Output = i32;

    async fn definition(&self, _prompt: String) -> ToolDefinition {
        serde_json::from_value(json!({
            "name": "add",
            "description": "Add x and y together",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The first number to add"
                    },
                    "y": {
                        "type": "number",
                        "description": "The second number to add"
                    }
                }
            }
        }))
        .expect("Tool Definition")
    }

    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        let result = args.x + args.y;
        Ok(result)
    }
}

#[derive(Deserialize, Serialize)]
struct Subtract;
impl Tool for Subtract {
    const NAME: &'static str = "subtract";

    type Error = MathError;
    type Args = OperationArgs;
    type Output = i32;

    async fn definition(&self, _prompt: String) -> ToolDefinition {
        serde_json::from_value(json!({
            "name": "subtract",
            "description": "Subtract y from x (i.e.: x - y)",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The number to subtract from"
                    },
                    "y": {
                        "type": "number",
                        "description": "The number to subtract"
                    }
                }
            }
        }))
        .expect("Tool Definition")
    }

    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        let result = args.x - args.y;
        Ok(result)
    }
}

struct Multiply;
impl Tool for Multiply {
    const NAME: &'static str = "multiply";

    type Error = MathError;
    type Args = OperationArgs;
    type Output = i32;

    async fn definition(&self, _prompt: String) -> ToolDefinition {
        serde_json::from_value(json!({
            "name": "multiply",
            "description": "Compute the product of x and y (i.e.: x * y)",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The first factor in the product"
                    },
                    "y": {
                        "type": "number",
                        "description": "The second factor in the product"
                    }
                }
            }
        }))
        .expect("Tool Definition")
    }

    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        let result = args.x * args.y;
        Ok(result)
    }
}

struct Divide;
impl Tool for Divide {
    const NAME: &'static str = "divide";

    type Error = MathError;
    type Args = OperationArgs;
    type Output = i32;

    async fn definition(&self, _prompt: String) -> ToolDefinition {
        serde_json::from_value(json!({
            "name": "divide",
            "description": "Compute the quotient of x and y (i.e.: x / y). Useful for ratios.",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The dividend of the division (the number being divided)"
                    },
                    "y": {
                        "type": "number",
                        "description": "The divisor of the division (the number by which the dividend is divided)"
                    }
                }
            }
        }))
        .expect("Tool Definition")
    }

    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        let result = args.x / args.y;
        Ok(result)
    }
}
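The four arithmetic tools share the same argument shape, so each can be exercised directly without a model in the loop. A minimal check of that contract, illustrative only and not part of the commit (assumes tokio's test macro is available):

// Drive a tool directly through the `Tool` trait, bypassing the LLM,
// to verify that `call` and the argument type line up.
#[tokio::test]
async fn add_tool_computes_sum() {
    let out = Add
        .call(OperationArgs { x: 19, y: 23 })
        .await
        .expect("add should not fail");
    assert_eq!(out, 42);
}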
@@ -1,684 +0,0 @@
//! This module contains the implementation of the [Agent] struct and its builder.
//!
//! The [Agent] struct represents an LLM agent, which combines an LLM model with a preamble (system prompt),
//! a set of context documents, and a set of tools. Note: both context documents and tools can be either
//! static (i.e.: they are always provided) or dynamic (i.e.: they are RAGged at prompt-time).
//!
//! The [Agent] struct is highly configurable, allowing the user to define anything from
//! a simple bot with a specific system prompt to a complex RAG system with a set of dynamic
//! context documents and tools.
//!
//! The [Agent] struct implements the [Completion] and [Prompt] traits, allowing it to be used for generating
//! completion responses and prompts. The [Agent] struct also implements the [Chat] trait, which allows it to
//! be used for generating chat completions.
//!
//! The [AgentBuilder] implements the builder pattern for creating instances of [Agent].
//! It allows configuring the model, preamble, context documents, tools, temperature, and additional parameters
//! before building the agent.
//!
//! # Example
//! ```rust
//! use rig::{
//!     completion::{Chat, Completion, Prompt},
//!     providers::openai,
//! };
//!
//! let openai = openai::Client::from_env();
//!
//! // Configure the agent
//! let agent = openai.agent("gpt-4o")
//!     .preamble("System prompt")
//!     .context("Context document 1")
//!     .context("Context document 2")
//!     .tool(tool1)
//!     .tool(tool2)
//!     .temperature(0.8)
//!     .additional_params(json!({"foo": "bar"}))
//!     .build();
//!
//! // Use the agent for completions and prompts
//! // Generate a chat completion response from a prompt and chat history
//! let chat_response = agent.chat("Prompt", chat_history)
//!     .await
//!     .expect("Failed to chat with Agent");
//!
//! // Generate a prompt completion response from a simple prompt
//! let prompt_response = agent.prompt("Prompt")
//!     .await
//!     .expect("Failed to prompt the Agent");
//!
//! // Generate a completion request builder from a prompt and chat history. The builder
//! // will contain the agent's configuration (i.e.: preamble, context documents, tools,
//! // model parameters, etc.), but these can be overwritten.
//! let completion_req_builder = agent.completion("Prompt", chat_history)
//!     .await
//!     .expect("Failed to create completion request builder");
//!
//! let response = completion_req_builder
//!     .temperature(0.9) // Overwrite the agent's temperature
//!     .send()
//!     .await
//!     .expect("Failed to send completion request");
//! ```
//!
//! RAG Agent example
//! ```rust
//! use rig::{
//!     completion::Prompt,
//!     embeddings::EmbeddingsBuilder,
//!     providers::openai,
//!     vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
//! };
//!
//! // Initialize OpenAI client
//! let openai = openai::Client::from_env();
//!
//! // Initialize OpenAI embedding model
//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002);
//!
//! // Create vector store, compute embeddings and load them in the store
//! let mut vector_store = InMemoryVectorStore::default();
//!
//! let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
//!     .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
//!     .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is an ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
//!     .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
//!     .build()
//!     .await
//!     .expect("Failed to build embeddings");
//!
//! vector_store.add_documents(embeddings)
//!     .await
//!     .expect("Failed to add documents");
//!
//! // Create vector store index
//! let index = vector_store.index(embedding_model);
//!
//! let agent = openai.agent(openai::GPT_4O)
//!     .preamble("
//!         You are a dictionary assistant here to assist the user in understanding the meaning of words.
//!         You will find additional non-standard word definitions that could be useful below.
//!     ")
//!     .dynamic_context(1, index)
//!     .build();
//!
//! // Prompt the agent and print the response
//! let response = agent.prompt("What does \"glarb-glarb\" mean?").await
//!     .expect("Failed to prompt the agent");
//! ```
use std::{collections::HashMap, future::IntoFuture};

use futures::{future::BoxFuture, stream, FutureExt, StreamExt, TryStreamExt};

use crate::{
    completion::{
        Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
        Message, Prompt, PromptError,
    },
    message::{AssistantContent, UserContent},
    streaming::{
        StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
        StreamingResult,
    },
    tool::{Tool, ToolSet, ToolSetError},
    vector_store::{VectorStoreError, VectorStoreIndexDyn},
    OneOrMany,
};

#[cfg(feature = "mcp")]
use crate::tool::McpTool;

/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
/// All context documents and tools are always provided to the agent when prompted.
///
/// # Example
/// ```
/// use rig::{completion::Prompt, providers::openai};
///
/// let openai = openai::Client::from_env();
///
/// let comedian_agent = openai
///     .agent("gpt-4o")
///     .preamble("You are a comedian here to entertain the user using humour and jokes.")
///     .temperature(0.9)
///     .build();
///
/// let response = comedian_agent.prompt("Entertain me!")
///     .await
///     .expect("Failed to prompt the agent");
/// ```
pub struct Agent<M: CompletionModel> {
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    model: M,
    /// System prompt
    preamble: String,
    /// Context documents always available to the agent
    static_context: Vec<Document>,
    /// Tools that are always available to the agent (identified by their name)
    static_tools: Vec<String>,
    /// Temperature of the model
    temperature: Option<f64>,
    /// Maximum number of tokens for the completion
    max_tokens: Option<u64>,
    /// Additional parameters to be passed to the model
    additional_params: Option<serde_json::Value>,
    /// List of vector store indices, with the number of documents to sample per prompt
    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Dynamic tools
    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Actual tool implementations
    pub tools: ToolSet,
}

impl<M: CompletionModel> Completion<M> for Agent<M> {
    async fn completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        let prompt = prompt.into();
        let rag_text = prompt.rag_text().clone();

        let completion_request = self
            .model
            .completion_request(prompt)
            .preamble(self.preamble.clone())
            .messages(chat_history)
            .temperature_opt(self.temperature)
            .max_tokens_opt(self.max_tokens)
            .additional_params_opt(self.additional_params.clone())
            .documents(self.static_context.clone());

        let agent = match &rag_text {
            Some(text) => {
                let dynamic_context = stream::iter(self.dynamic_context.iter())
                    .then(|(num_sample, index)| async {
                        Ok::<_, VectorStoreError>(
                            index
                                .top_n(text, *num_sample)
                                .await?
                                .into_iter()
                                .map(|(_, id, doc)| {
                                    // Pretty print the document if possible for better readability
                                    let text = serde_json::to_string_pretty(&doc)
                                        .unwrap_or_else(|_| doc.to_string());

                                    Document {
                                        id,
                                        text,
                                        additional_props: HashMap::new(),
                                    }
                                })
                                .collect::<Vec<_>>(),
                        )
                    })
                    .try_fold(vec![], |mut acc, docs| async {
                        acc.extend(docs);
                        Ok(acc)
                    })
                    .await
                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

                let dynamic_tools = stream::iter(self.dynamic_tools.iter())
                    .then(|(num_sample, index)| async {
                        Ok::<_, VectorStoreError>(
                            index
                                .top_n_ids(text, *num_sample)
                                .await?
                                .into_iter()
                                .map(|(_, id)| id)
                                .collect::<Vec<_>>(),
                        )
                    })
                    .try_fold(vec![], |mut acc, docs| async {
                        for doc in docs {
                            if let Some(tool) = self.tools.get(&doc) {
                                acc.push(tool.definition(text.into()).await)
                            } else {
                                tracing::warn!("Tool implementation not found in toolset: {}", doc);
                            }
                        }
                        Ok(acc)
                    })
                    .await
                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

                let static_tools = stream::iter(self.static_tools.iter())
                    .filter_map(|toolname| async move {
                        if let Some(tool) = self.tools.get(toolname) {
                            Some(tool.definition(text.into()).await)
                        } else {
                            tracing::warn!(
                                "Tool implementation not found in toolset: {}",
                                toolname
                            );
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .await;

                completion_request
                    .documents(dynamic_context)
                    .tools([static_tools.clone(), dynamic_tools].concat())
            }
            None => {
                let static_tools = stream::iter(self.static_tools.iter())
                    .filter_map(|toolname| async move {
                        if let Some(tool) = self.tools.get(toolname) {
                            // TODO: tool definitions should likely take an `Option<String>`
                            Some(tool.definition("".into()).await)
                        } else {
                            tracing::warn!(
                                "Tool implementation not found in toolset: {}",
                                toolname
                            );
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .await;

                completion_request.tools(static_tools)
            }
        };

        Ok(agent)
    }
}

// Here, we need to ensure that usage of `.prompt` on agent uses these redefinitions on the opaque
// `Prompt` trait so that when `.prompt` is used at the call-site, it'll use the more specific
// `PromptRequest` implementation for `Agent`, making the builder's usage fluent.
//
// References:
// - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait)

#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for Agent<M> {
    fn prompt(&self, prompt: impl Into<Message> + Send) -> PromptRequest<M> {
        PromptRequest::new(self, prompt)
    }
}

#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for &Agent<M> {
    fn prompt(&self, prompt: impl Into<Message> + Send) -> PromptRequest<M> {
        PromptRequest::new(*self, prompt)
    }
}

#[allow(refining_impl_trait)]
impl<M: CompletionModel> Chat for Agent<M> {
    async fn chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> Result<String, PromptError> {
        let mut cloned_history = chat_history.clone();
        PromptRequest::new(self, prompt)
            .with_history(&mut cloned_history)
            .await
    }
}

/// A builder for creating an agent
///
/// # Example
/// ```
/// use rig::{providers::openai, agent::AgentBuilder};
///
/// let openai = openai::Client::from_env();
///
/// let gpt4o = openai.completion_model("gpt-4o");
///
/// // Configure the agent
/// let agent = AgentBuilder::new(gpt4o)
///     .preamble("System prompt")
///     .context("Context document 1")
///     .context("Context document 2")
///     .tool(tool1)
///     .tool(tool2)
///     .temperature(0.8)
///     .additional_params(json!({"foo": "bar"}))
///     .build();
/// ```
pub struct AgentBuilder<M: CompletionModel> {
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    model: M,
    /// System prompt
    preamble: Option<String>,
    /// Context documents always available to the agent
    static_context: Vec<Document>,
    /// Tools that are always available to the agent (by name)
    static_tools: Vec<String>,
    /// Additional parameters to be passed to the model
    additional_params: Option<serde_json::Value>,
    /// Maximum number of tokens for the completion
    max_tokens: Option<u64>,
    /// List of vector store indices, with the number of documents to sample per prompt
    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Dynamic tools
    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Temperature of the model
    temperature: Option<f64>,
    /// Actual tool implementations
    tools: ToolSet,
}

impl<M: CompletionModel> AgentBuilder<M> {
    pub fn new(model: M) -> Self {
        Self {
            model,
            preamble: None,
            static_context: vec![],
            static_tools: vec![],
            temperature: None,
            max_tokens: None,
            additional_params: None,
            dynamic_context: vec![],
            dynamic_tools: vec![],
            tools: ToolSet::default(),
        }
    }

    /// Set the system prompt
    pub fn preamble(mut self, preamble: &str) -> Self {
        self.preamble = Some(preamble.into());
        self
    }

    /// Append to the preamble of the agent
    pub fn append_preamble(mut self, doc: &str) -> Self {
        self.preamble = Some(format!(
            "{}\n{}",
            self.preamble.unwrap_or_else(|| "".into()),
            doc
        ));
        self
    }

    /// Add a static context document to the agent
    pub fn context(mut self, doc: &str) -> Self {
        self.static_context.push(Document {
            id: format!("static_doc_{}", self.static_context.len()),
            text: doc.into(),
            additional_props: HashMap::new(),
        });
        self
    }

    /// Add a static tool to the agent
    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
        let toolname = tool.name();
        self.tools.add_tool(tool);
        self.static_tools.push(toolname);
        self
    }

    /// Add an MCP tool to the agent
    #[cfg(feature = "mcp")]
    pub fn mcp_tool<T: mcp_core::transport::Transport>(
        mut self,
        tool: mcp_core::types::Tool,
        client: mcp_core::client::Client<T>,
    ) -> Self {
        let toolname = tool.name.clone();
        self.tools.add_tool(McpTool::from_mcp_server(tool, client));
        self.static_tools.push(toolname);
        self
    }

    /// Add some dynamic context to the agent. On each prompt, `sample` documents from the
    /// dynamic context will be inserted in the request.
    pub fn dynamic_context(
        mut self,
        sample: usize,
        dynamic_context: impl VectorStoreIndexDyn + 'static,
    ) -> Self {
        self.dynamic_context
            .push((sample, Box::new(dynamic_context)));
        self
    }

    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
    /// dynamic toolset will be inserted in the request.
    pub fn dynamic_tools(
        mut self,
        sample: usize,
        dynamic_tools: impl VectorStoreIndexDyn + 'static,
        toolset: ToolSet,
    ) -> Self {
        self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
        self.tools.add_tools(toolset);
        self
    }

    /// Set the temperature of the model
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Set the maximum number of tokens for the completion
    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Set additional parameters to be passed to the model
    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
        self.additional_params = Some(params);
        self
    }

    /// Build the agent
    pub fn build(self) -> Agent<M> {
        Agent {
            model: self.model,
            preamble: self.preamble.unwrap_or_default(),
            static_context: self.static_context,
            static_tools: self.static_tools,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
            dynamic_context: self.dynamic_context,
            dynamic_tools: self.dynamic_tools,
            tools: self.tools,
        }
    }
}

/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
pub struct PromptRequest<'c, 'a, M: CompletionModel> {
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt
    /// Note: chat history needs to outlive the agent as it might be used with other agents
    chat_history: Option<&'c mut Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_depth: usize,
    /// The agent to use for execution
    agent: &'a Agent<M>,
}

impl<'c: 'a, 'a, M: CompletionModel> PromptRequest<'c, 'a, M> {
    /// Create a new PromptRequest with the given prompt and model
    pub fn new(agent: &'c Agent<M>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: 0,
            agent,
        }
    }
}

impl<'c, 'a, M: CompletionModel> PromptRequest<'c, 'a, M> {
    /// Set the maximum depth for multi-turn conversations
    pub fn multi_turn(self, depth: usize) -> PromptRequest<'c, 'a, M> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: depth,
            agent: self.agent,
        }
    }

    /// Add chat history to the prompt request
    pub fn with_history(self, history: &'c mut Vec<Message>) -> PromptRequest<'c, 'a, M> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: Some(history),
            max_depth: self.max_depth,
            agent: self.agent,
        }
    }
}

/// Due to RFC 2515, we have to use a `BoxFuture` for the `IntoFuture` implementation. In the
/// future, we should be able to use `impl Future<...>` directly via the associated type.
///
/// Ref: https://github.com/rust-lang/rust/issues/63063
impl<'c: 'a, 'a, M: CompletionModel + 'c> IntoFuture for PromptRequest<'c, 'a, M> {
    type Output = Result<String, PromptError>;
    type IntoFuture = BoxFuture<'a, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        self.send().boxed()
    }
}

// Implementation for Agent
impl<M: CompletionModel> PromptRequest<'_, '_, M> {
    async fn send(self) -> Result<String, PromptError> {
        let agent = self.agent;
        let mut prompt = self.prompt;
        let chat_history = if let Some(history) = self.chat_history {
            history
        } else {
            &mut Vec::new()
        };

        let mut current_max_depth = 0;
        while current_max_depth <= self.max_depth {
            current_max_depth += 1;

            if self.max_depth > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_depth,
                    self.max_depth
                );
            }

            let resp = agent
                .completion(prompt.clone(), chat_history.to_vec())
                .await?
                .send()
                .await?;

            chat_history.push(prompt);

            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            chat_history.push(Message::Assistant {
                content: resp.choice.clone(),
            });

            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_depth > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
                }

                // If there are no tool calls, depth is not relevant, we can just return the merged text.
                return Ok(merged_texts);
            }

            let tool_content = stream::iter(tool_calls)
                .then(async |choice| {
                    if let AssistantContent::ToolCall(tool_call) = choice {
                        let output = agent
                            .tools
                            .call(
                                &tool_call.function.name,
                                tool_call.function.arguments.to_string(),
                            )
                            .await?;
                        Ok(UserContent::tool_result(
                            tool_call.id.clone(),
                            OneOrMany::one(output.into()),
                        ))
                    } else {
                        unreachable!(
                            "This should never happen as we already filtered for `ToolCall`"
                        )
                    }
                })
                .collect::<Vec<Result<UserContent, ToolSetError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

            prompt = Message::User {
                content: OneOrMany::many(tool_content).expect("There is at least one tool call"),
            };
        }

        // If we reach here, we never resolved the final tool call. We need to do ... something.
        Err(PromptError::MaxDepthError {
            max_depth: self.max_depth,
            chat_history: chat_history.clone(),
            prompt,
        })
    }
}

impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
    async fn stream_completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        // Reuse the existing completion implementation to build the request
        // This ensures streaming and non-streaming use the same request building logic
        self.completion(prompt, chat_history).await
    }
}

impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
    async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
        self.stream_chat(prompt, vec![]).await
    }
}

impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
    async fn stream_chat(
        &self,
        prompt: &str,
        chat_history: Vec<Message>,
    ) -> Result<StreamingResult, CompletionError> {
        self.stream_completion(prompt, chat_history)
            .await?
            .stream()
            .await
    }
}
@@ -0,0 +1,251 @@
use std::collections::HashMap;

use futures::{stream, StreamExt, TryStreamExt};

use crate::{
    completion::{
        Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
        Message, Prompt, PromptError,
    },
    streaming::{
        StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
        StreamingResult,
    },
    tool::ToolSet,
    vector_store::VectorStoreError,
};

use super::prompt_request;
use super::prompt_request::PromptRequest;

/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
/// All context documents and tools are always provided to the agent when prompted.
///
/// # Example
/// ```
/// use rig::{completion::Prompt, providers::openai};
///
/// let openai = openai::Client::from_env();
///
/// let comedian_agent = openai
///     .agent("gpt-4o")
///     .preamble("You are a comedian here to entertain the user using humour and jokes.")
///     .temperature(0.9)
///     .build();
///
/// let response = comedian_agent.prompt("Entertain me!")
///     .await
///     .expect("Failed to prompt the agent");
/// ```
pub struct Agent<M: CompletionModel> {
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    pub model: M,
    /// System prompt
    pub preamble: String,
    /// Context documents always available to the agent
    pub static_context: Vec<Document>,
    /// Tools that are always available to the agent (identified by their name)
    pub static_tools: Vec<String>,
    /// Temperature of the model
    pub temperature: Option<f64>,
    /// Maximum number of tokens for the completion
    pub max_tokens: Option<u64>,
    /// Additional parameters to be passed to the model
    pub additional_params: Option<serde_json::Value>,
    /// List of vector store indices, with the number of documents to sample per prompt
    pub dynamic_context: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
    /// Dynamic tools
    pub dynamic_tools: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
    /// Actual tool implementations
    pub tools: ToolSet,
}

impl<M: CompletionModel> Completion<M> for Agent<M> {
    async fn completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        let prompt = prompt.into();
        let rag_text = prompt.rag_text().clone();

        let completion_request = self
            .model
            .completion_request(prompt)
            .preamble(self.preamble.clone())
            .messages(chat_history)
            .temperature_opt(self.temperature)
            .max_tokens_opt(self.max_tokens)
            .additional_params_opt(self.additional_params.clone())
            .documents(self.static_context.clone());

        let agent = match &rag_text {
            Some(text) => {
                let dynamic_context = stream::iter(self.dynamic_context.iter())
                    .then(|(num_sample, index)| async {
                        Ok::<_, VectorStoreError>(
                            index
                                .top_n(text, *num_sample)
                                .await?
                                .into_iter()
                                .map(|(_, id, doc)| {
                                    // Pretty print the document if possible for better readability
                                    let text = serde_json::to_string_pretty(&doc)
                                        .unwrap_or_else(|_| doc.to_string());

                                    Document {
                                        id,
                                        text,
                                        additional_props: HashMap::new(),
                                    }
                                })
                                .collect::<Vec<_>>(),
                        )
                    })
                    .try_fold(vec![], |mut acc, docs| async {
                        acc.extend(docs);
                        Ok(acc)
                    })
                    .await
                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

                let dynamic_tools = stream::iter(self.dynamic_tools.iter())
                    .then(|(num_sample, index)| async {
                        Ok::<_, VectorStoreError>(
                            index
                                .top_n_ids(text, *num_sample)
                                .await?
                                .into_iter()
                                .map(|(_, id)| id)
                                .collect::<Vec<_>>(),
                        )
                    })
                    .try_fold(vec![], |mut acc, docs| async {
                        for doc in docs {
                            if let Some(tool) = self.tools.get(&doc) {
                                acc.push(tool.definition(text.into()).await)
                            } else {
                                tracing::warn!("Tool implementation not found in toolset: {}", doc);
                            }
                        }
                        Ok(acc)
                    })
                    .await
                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

                let static_tools = stream::iter(self.static_tools.iter())
                    .filter_map(|toolname| async move {
                        if let Some(tool) = self.tools.get(toolname) {
                            Some(tool.definition(text.into()).await)
                        } else {
                            tracing::warn!(
                                "Tool implementation not found in toolset: {}",
                                toolname
                            );
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .await;

                completion_request
                    .documents(dynamic_context)
                    .tools([static_tools.clone(), dynamic_tools].concat())
            }
            None => {
                let static_tools = stream::iter(self.static_tools.iter())
                    .filter_map(|toolname| async move {
                        if let Some(tool) = self.tools.get(toolname) {
                            // TODO: tool definitions should likely take an `Option<String>`
                            Some(tool.definition("".into()).await)
                        } else {
                            tracing::warn!(
                                "Tool implementation not found in toolset: {}",
                                toolname
                            );
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .await;

                completion_request.tools(static_tools)
            }
        };

        Ok(agent)
    }
}

// Here, we need to ensure that usage of `.prompt` on agent uses these redefinitions on the opaque
// `Prompt` trait so that when `.prompt` is used at the call-site, it'll use the more specific
// `PromptRequest` implementation for `Agent`, making the builder's usage fluent.
//
// References:
// - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait)

#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for Agent<M> {
    fn prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> PromptRequest<M, prompt_request::Simple> {
        PromptRequest::new(self, prompt)
    }
}

#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for &Agent<M> {
    fn prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> PromptRequest<M, prompt_request::Simple> {
        PromptRequest::new(*self, prompt)
    }
}

#[allow(refining_impl_trait)]
impl<M: CompletionModel> Chat for Agent<M> {
    async fn chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> Result<String, PromptError> {
        let mut cloned_history = chat_history.clone();
        PromptRequest::new(self, prompt)
            .with_history(&mut cloned_history)
            .await
    }
}

impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
    async fn stream_completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        // Reuse the existing completion implementation to build the request
        // This ensures streaming and non-streaming use the same request building logic
        self.completion(prompt, chat_history).await
    }
}

impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
    async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
        self.stream_chat(prompt, vec![]).await
    }
}

impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
    async fn stream_chat(
        &self,
        prompt: &str,
        chat_history: Vec<Message>,
    ) -> Result<StreamingResult, CompletionError> {
        self.stream_completion(prompt, chat_history)
            .await?
            .stream()
            .await
    }
}
@@ -0,0 +1,179 @@
use std::collections::HashMap;

use crate::{
    completion::{CompletionModel, Document},
    tool::{Tool, ToolSet},
    vector_store::VectorStoreIndexDyn,
};

#[cfg(feature = "mcp")]
use crate::tool::McpTool;

use super::Agent;

/// A builder for creating an agent
///
/// # Example
/// ```
/// use rig::{providers::openai, agent::AgentBuilder};
///
/// let openai = openai::Client::from_env();
///
/// let gpt4o = openai.completion_model("gpt-4o");
///
/// // Configure the agent
/// let agent = AgentBuilder::new(gpt4o)
///     .preamble("System prompt")
///     .context("Context document 1")
///     .context("Context document 2")
///     .tool(tool1)
///     .tool(tool2)
///     .temperature(0.8)
///     .additional_params(json!({"foo": "bar"}))
///     .build();
/// ```
pub struct AgentBuilder<M: CompletionModel> {
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    model: M,
    /// System prompt
    preamble: Option<String>,
    /// Context documents always available to the agent
    static_context: Vec<Document>,
    /// Tools that are always available to the agent (by name)
    static_tools: Vec<String>,
    /// Additional parameters to be passed to the model
    additional_params: Option<serde_json::Value>,
    /// Maximum number of tokens for the completion
    max_tokens: Option<u64>,
    /// List of vector store indices, with the number of documents to sample per prompt
    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Dynamic tools
    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Temperature of the model
    temperature: Option<f64>,
    /// Actual tool implementations
    tools: ToolSet,
}

impl<M: CompletionModel> AgentBuilder<M> {
    pub fn new(model: M) -> Self {
        Self {
            model,
            preamble: None,
            static_context: vec![],
            static_tools: vec![],
            temperature: None,
            max_tokens: None,
            additional_params: None,
            dynamic_context: vec![],
            dynamic_tools: vec![],
            tools: ToolSet::default(),
        }
    }

    /// Set the system prompt
    pub fn preamble(mut self, preamble: &str) -> Self {
        self.preamble = Some(preamble.into());
        self
    }

    /// Append to the preamble of the agent
    pub fn append_preamble(mut self, doc: &str) -> Self {
        self.preamble = Some(format!(
            "{}\n{}",
            self.preamble.unwrap_or_else(|| "".into()),
            doc
        ));
        self
    }

    /// Add a static context document to the agent
    pub fn context(mut self, doc: &str) -> Self {
        self.static_context.push(Document {
            id: format!("static_doc_{}", self.static_context.len()),
            text: doc.into(),
            additional_props: HashMap::new(),
        });
        self
    }

    /// Add a static tool to the agent
    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
        let toolname = tool.name();
        self.tools.add_tool(tool);
        self.static_tools.push(toolname);
        self
    }

    /// Add an MCP tool to the agent
    #[cfg(feature = "mcp")]
    pub fn mcp_tool<T: mcp_core::transport::Transport>(
        mut self,
        tool: mcp_core::types::Tool,
        client: mcp_core::client::Client<T>,
    ) -> Self {
        let toolname = tool.name.clone();
        self.tools.add_tool(McpTool::from_mcp_server(tool, client));
        self.static_tools.push(toolname);
        self
    }

    /// Add some dynamic context to the agent. On each prompt, `sample` documents from the
    /// dynamic context will be inserted in the request.
    pub fn dynamic_context(
        mut self,
        sample: usize,
        dynamic_context: impl VectorStoreIndexDyn + 'static,
    ) -> Self {
        self.dynamic_context
            .push((sample, Box::new(dynamic_context)));
        self
    }

    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
    /// dynamic toolset will be inserted in the request.
    pub fn dynamic_tools(
        mut self,
        sample: usize,
        dynamic_tools: impl VectorStoreIndexDyn + 'static,
        toolset: ToolSet,
    ) -> Self {
        self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
        self.tools.add_tools(toolset);
        self
    }

    /// Set the temperature of the model
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Set the maximum number of tokens for the completion
    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Set additional parameters to be passed to the model
    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
        self.additional_params = Some(params);
        self
    }

    /// Build the agent
    pub fn build(self) -> Agent<M> {
        Agent {
            model: self.model,
            preamble: self.preamble.unwrap_or_default(),
            static_context: self.static_context,
            static_tools: self.static_tools,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
            dynamic_context: self.dynamic_context,
            dynamic_tools: self.dynamic_tools,
            tools: self.tools,
        }
    }
}
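The builder's `dynamic_tools` hook is the one configuration method without an example in its doc comment. A hedged sketch of how it composes with the rest (here `model`, `index`, and `toolset` are assumed: a completion model, a vector store index over tool embeddings, and the matching `Tool` implementations):

// Sketch: sample 2 tool definitions per prompt from `index`; the actual
// implementations are looked up by name in `toolset` at call time.
let agent = AgentBuilder::new(model)
    .preamble("You pick and call tools for the user.")
    .dynamic_tools(2, index, toolset)
    .build();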
@@ -0,0 +1,116 @@
//! This module contains the implementation of the [Agent] struct and its builder.
//!
//! The [Agent] struct represents an LLM agent, which combines an LLM model with a preamble (system prompt),
//! a set of context documents, and a set of tools. Note: both context documents and tools can be either
//! static (i.e.: they are always provided) or dynamic (i.e.: they are RAGged at prompt-time).
//!
//! The [Agent] struct is highly configurable, allowing the user to define anything from
//! a simple bot with a specific system prompt to a complex RAG system with a set of dynamic
//! context documents and tools.
//!
//! The [Agent] struct implements the [Completion] and [Prompt] traits, allowing it to be used for generating
//! completion responses and prompts. The [Agent] struct also implements the [Chat] trait, which allows it to
//! be used for generating chat completions.
//!
//! The [AgentBuilder] implements the builder pattern for creating instances of [Agent].
//! It allows configuring the model, preamble, context documents, tools, temperature, and additional parameters
//! before building the agent.
//!
//! # Example
//! ```rust
//! use rig::{
//!     completion::{Chat, Completion, Prompt},
//!     providers::openai,
//! };
//!
//! let openai = openai::Client::from_env();
//!
//! // Configure the agent
//! let agent = openai.agent("gpt-4o")
//!     .preamble("System prompt")
//!     .context("Context document 1")
//!     .context("Context document 2")
//!     .tool(tool1)
//!     .tool(tool2)
//!     .temperature(0.8)
//!     .additional_params(json!({"foo": "bar"}))
//!     .build();
//!
//! // Use the agent for completions and prompts
//! // Generate a chat completion response from a prompt and chat history
//! let chat_response = agent.chat("Prompt", chat_history)
//!     .await
//!     .expect("Failed to chat with Agent");
//!
//! // Generate a prompt completion response from a simple prompt
//! let prompt_response = agent.prompt("Prompt")
//!     .await
//!     .expect("Failed to prompt the Agent");
//!
//! // Generate a completion request builder from a prompt and chat history. The builder
//! // will contain the agent's configuration (i.e.: preamble, context documents, tools,
//! // model parameters, etc.), but these can be overwritten.
//! let completion_req_builder = agent.completion("Prompt", chat_history)
//!     .await
//!     .expect("Failed to create completion request builder");
//!
//! let response = completion_req_builder
//!     .temperature(0.9) // Overwrite the agent's temperature
//!     .send()
//!     .await
//!     .expect("Failed to send completion request");
//! ```
//!
//! RAG Agent example
//! ```rust
//! use rig::{
//!     completion::Prompt,
//!     embeddings::EmbeddingsBuilder,
//!     providers::openai,
//!     vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
//! };
//!
//! // Initialize OpenAI client
//! let openai = openai::Client::from_env();
//!
//! // Initialize OpenAI embedding model
//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002);
//!
//! // Create vector store, compute embeddings and load them in the store
//! let mut vector_store = InMemoryVectorStore::default();
//!
//! let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
//!     .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
//!     .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is an ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
//!     .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
//!     .build()
//!     .await
//!     .expect("Failed to build embeddings");
//!
//! vector_store.add_documents(embeddings)
//!     .await
//!     .expect("Failed to add documents");
//!
//! // Create vector store index
//! let index = vector_store.index(embedding_model);
//!
//! let agent = openai.agent(openai::GPT_4O)
//!     .preamble("
//!         You are a dictionary assistant here to assist the user in understanding the meaning of words.
//!         You will find additional non-standard word definitions that could be useful below.
//!     ")
//!     .dynamic_context(1, index)
//!     .build();
//!
//! // Prompt the agent and print the response
//! let response = agent.prompt("What does \"glarb-glarb\" mean?").await
//!     .expect("Failed to prompt the agent");
//! ```

mod agent;
mod builder;
mod prompt_request;

pub use agent::Agent;
pub use builder::AgentBuilder;
pub use prompt_request::PromptRequest;
@@ -0,0 +1,232 @@
use std::{
    future::{Future, IntoFuture},
    marker::PhantomData,
};

use futures::{future::BoxFuture, stream, FutureExt, StreamExt};

use crate::{
    completion::{Completion, CompletionError, CompletionModel, Message, PromptError},
    message::{AssistantContent, UserContent},
    tool::ToolSetError,
    OneOrMany,
};

use super::Agent;

pub trait State {}
pub struct Simple;
pub struct MultiTurn;

impl State for Simple {}
impl State for MultiTurn {}

pub trait SendPromptRequest<M: CompletionModel, T: State> {
    fn send(self) -> impl Future<Output = Result<String, PromptError>> + Send;
}

/// A builder for creating prompt requests with customizable options.
/// Uses a typestate parameter to track which options have been set during the build process.
pub struct PromptRequest<'c, 'a, M: CompletionModel, T: State> {
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt
    /// Note: chat history needs to outlive the agent as it might be used with other agents
    chat_history: Option<&'c mut Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_depth: usize,
    /// The agent to use for execution
    agent: &'a Agent<M>,

    /// Typestate
    _state: PhantomData<T>,
}

impl<'c: 'a, 'a, M: CompletionModel> PromptRequest<'c, 'a, M, Simple> {
    /// Create a new PromptRequest with the given prompt and model
    pub fn new(agent: &'c Agent<M>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: 0,
            agent,
            _state: PhantomData,
        }
    }
}

impl<'c, 'a, M: CompletionModel> PromptRequest<'c, 'a, M, Simple> {
    /// Set the maximum depth for multi-turn conversations
    pub fn multi_turn(self, depth: usize) -> PromptRequest<'c, 'a, M, MultiTurn> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: depth,
            agent: self.agent,
            _state: PhantomData,
        }
    }

    /// Add chat history to the prompt request
    pub fn with_history(self, history: &'c mut Vec<Message>) -> PromptRequest<'c, 'a, M, Simple> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: Some(history),
            max_depth: self.max_depth,
            agent: self.agent,
            _state: PhantomData,
        }
    }
}

/// Due to RFC 2515, we have to use a `BoxFuture` for the `IntoFuture` implementation. In the
/// future, we should be able to use `impl Future<...>` directly via the associated type.
///
/// Ref: https://github.com/rust-lang/rust/issues/63063
impl<'c: 'a, 'a, M: CompletionModel, T: State + 'a> IntoFuture for PromptRequest<'c, 'a, M, T>
where
    PromptRequest<'c, 'a, M, T>: SendPromptRequest<M, T>,
{
    type Output = Result<String, PromptError>;
    type IntoFuture = BoxFuture<'a, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        self.send().boxed()
    }
}

impl<M: CompletionModel, T: State> SendPromptRequest<M, T> for PromptRequest<'_, '_, M, MultiTurn> {
    async fn send(self) -> Result<String, PromptError> {
        let agent = self.agent;
        let mut prompt = self.prompt;
        let chat_history = if let Some(history) = self.chat_history {
            history
        } else {
            &mut Vec::new()
        };

        let mut current_max_depth = 0;
        // We need at least 2 loop iterations for 1 round trip (the user expects a normal message)
        while current_max_depth <= self.max_depth + 1 {
            current_max_depth += 1;

            if self.max_depth > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_depth,
                    self.max_depth
                );
            }

            let resp = agent
                .completion(prompt.clone(), chat_history.to_vec())
                .await?
                .send()
                .await?;

            chat_history.push(prompt);

            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            chat_history.push(Message::Assistant {
                content: resp.choice.clone(),
            });

            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_depth > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
                }

                // If there are no tool calls, depth is not relevant, we can just return the merged text.
                return Ok(merged_texts);
            }

            let tool_content = stream::iter(tool_calls)
                .then(async |choice| {
                    if let AssistantContent::ToolCall(tool_call) = choice {
                        let output = agent
                            .tools
                            .call(
                                &tool_call.function.name,
                                tool_call.function.arguments.to_string(),
                            )
                            .await?;
                        Ok(UserContent::tool_result(
                            tool_call.id.clone(),
                            OneOrMany::one(output.into()),
                        ))
                    } else {
                        unreachable!(
                            "This should never happen as we already filtered for `ToolCall`"
                        )
                    }
                })
                .collect::<Vec<Result<UserContent, ToolSetError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

            prompt = Message::User {
                content: OneOrMany::many(tool_content).expect("There is at least one tool call"),
            };
        }

        // If we reach here, we never resolved the final tool call. We need to do ... something.
        Err(PromptError::MaxDepthError {
            max_depth: self.max_depth,
            chat_history: chat_history.clone(),
            prompt,
        })
    }
}

impl<M: CompletionModel, T: State> SendPromptRequest<M, T> for PromptRequest<'_, '_, M, Simple> {
    async fn send(self) -> Result<String, PromptError> {
        let chat_history = if let Some(history) = self.chat_history {
            history.clone()
        } else {
            Vec::new()
        };

        let resp = self
            .agent
            .completion(self.prompt, chat_history)
            .await?
            .send()
            .await?;

        tracing::debug!(?resp.choice);

        if resp.choice.len() > 1 {
            tracing::warn!("Parallel tool calls are only available when using multi-turn. Use `agent.prompt(...).multi_turn(depth).await`!");
        }

        match resp.choice.first() {
            AssistantContent::Text(text) => Ok(text.text.clone()),
            AssistantContent::ToolCall(tool_call) => Ok(self
                .agent
                .tools
                .call(
                    &tool_call.function.name,
                    tool_call.function.arguments.to_string(),
                )
                .await?),
        }
    }
}
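The net effect of the typestate split at the call site, as a hedged sketch (assumes an `agent` built as in the examples above and an enclosing function that returns a Result):

// Simple state: one completion; at most one tool call is resolved.
let one_shot: String = agent.prompt("What is 2 + 2?").await?;

// MultiTurn state: up to `depth` tool-calling round trips.
let with_tools: String = agent
    .prompt("Compute (3 + 4) * 5 using the tools")
    .multi_turn(5)
    .await?;

// `multi_turn` consumes the `Simple` request, so a chained call such as
// `agent.prompt("...").multi_turn(2).multi_turn(3)` does not type-check:
// `multi_turn` is only defined on `PromptRequest<_, _, M, Simple>`.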
@@ -477,6 +477,12 @@ impl From<String> for Text {
     }
 }
 
+impl From<&String> for Text {
+    fn from(text: &String) -> Self {
+        text.to_owned().into()
+    }
+}
+
 impl From<&str> for Text {
     fn from(text: &str) -> Self {
         text.to_owned().into()

@@ -507,6 +513,14 @@ impl From<&str> for Message {
     }
 }
 
+impl From<&String> for Message {
+    fn from(text: &String) -> Self {
+        Message::User {
+            content: OneOrMany::one(UserContent::Text(text.into())),
+        }
+    }
+}
+
 impl From<Text> for Message {
     fn from(text: Text) -> Self {
         Message::User {
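These new conversions let call sites pass a `&String` (for instance a field borrowed from a struct) straight into anything that accepts `impl Into<Message>`, without an explicit clone at the call site. A hedged sketch, assuming an `agent` as in the module docs:

// With `From<&String>` in place, borrowing a stored prompt is enough;
// the `.into()` conversion clones internally.
let saved_prompt: String = String::from("What does \"glarb-glarb\" mean?");
let response = agent.prompt(&saved_prompt).await?;
// `saved_prompt` is still usable here, since only a borrow was passed.
println!("{response} (prompt was: {saved_prompt})");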
@@ -37,6 +37,7 @@ use serde_json::json;
 use crate::{
     agent::{Agent, AgentBuilder},
     completion::{CompletionModel, Prompt, PromptError, ToolDefinition},
+    message::Message,
     tool::Tool,
 };

@@ -62,7 +63,7 @@ impl<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, M: CompletionModel>
 where
     M: Sync,
 {
-    pub async fn extract(&self, text: &str) -> Result<T, ExtractionError> {
+    pub async fn extract(&self, text: impl Into<Message> + Send) -> Result<T, ExtractionError> {
         let summary = self.agent.prompt(text).await?;
 
         if summary.is_empty() {
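With `extract` generalized from `&str` to `impl Into<Message> + Send`, an already-constructed `Message` (for instance one carried in chat history) can be fed to an extractor directly, while plain `&str` input keeps working through the `From<&str> for Message` conversion. A hedged sketch, assuming an `extractor` of type `Extractor<M, ChainOfThoughtSteps>` as in the reasoning-agent example:

// Both calls go through Into<Message>; no manual conversion is needed.
let from_str = extractor.extract("First open the valve, then start the pump.").await?;
let msg: Message = "First open the valve, then start the pump.".into();
let from_msg = extractor.extract(msg).await?;
println!("{:?} / {:?}", from_str.steps, from_msg.steps);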
@@ -1,9 +1,7 @@
 use std::future::IntoFuture;
 
 use crate::{
-    completion::{self, CompletionModel},
-    extractor::{ExtractionError, Extractor},
-    vector_store,
+    completion::{self, CompletionModel}, extractor::{ExtractionError, Extractor}, message::Message, vector_store
 };
 
 use super::Op;

@@ -129,13 +127,13 @@ impl<M, Input, Output> Op for Extract<M, Input, Output>
 where
     M: CompletionModel,
     Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
-    Input: Into<String> + Send + Sync,
+    Input: Into<Message> + Send + Sync,
 {
     type Input = Input;
     type Output = Result<Output, ExtractionError>;
 
     async fn call(&self, input: Self::Input) -> Self::Output {
-        self.extractor.extract(&input.into()).await
+        self.extractor.extract(input).await
     }
 }
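For the pipeline `Extract` op, this loosening means the op's input type can be anything convertible to a `Message`, and the extractor receives it unconverted. A hedged call-site sketch (assumes an `Extract` op value `extract_op` wired up elsewhere; only the `Op::call` signature shown in the hunk above is relied on):

// `Op::call` forwards any `Into<Message>` input to the extractor;
// `&str` satisfies the bound via `From<&str> for Message`.
let steps: ChainOfThoughtSteps = extract_op
    .call("Plan the steps to solve 20x + 23 = 400x / (1 - x)")
    .await?;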