mirror of https://github.com/0xplaygrounds/rig
refactor: blow up agent and introduce typestate
This commit is contained in:
parent
4b38294264
commit
44aa8e9f13
|
@ -0,0 +1,254 @@
|
||||||
|
use rig::{
|
||||||
|
agent::Agent,
|
||||||
|
completion::{CompletionError, CompletionModel, Prompt, PromptError, ToolDefinition},
|
||||||
|
extractor::Extractor,
|
||||||
|
message::Message,
|
||||||
|
providers::anthropic,
|
||||||
|
tool::Tool,
|
||||||
|
};
|
||||||
|
use schemars::JsonSchema;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
/// Preamble handed to the chain-of-thought extractor in `main`.
///
/// It tells the model to break the incoming prompt into reasoning steps and to
/// answer with a tool call only, never with free text.
const CHAIN_OF_THOUGHT_PROMPT: &str = "
You are an assistant that extracts reasoning steps from a given prompt.
Do not return text, only return a tool call.
";
|
||||||
|
|
||||||
|
#[derive(Deserialize, Serialize, Debug, Clone, JsonSchema)]
|
||||||
|
struct ChainOfThoughtSteps {
|
||||||
|
steps: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Two-stage agent: an extractor first turns the user prompt into a list of
/// reasoning steps, then an executor agent is prompted with those steps.
struct ReasoningAgent<M: CompletionModel> {
    /// Extractor that produces the [`ChainOfThoughtSteps`] for a prompt.
    chain_of_thought_extractor: Extractor<M, ChainOfThoughtSteps>,
    /// Agent that receives the numbered steps and produces the final answer.
    executor: Agent<M>,
}
|
||||||
|
|
||||||
|
impl<M: CompletionModel> Prompt for ReasoningAgent<M> {
|
||||||
|
#[allow(refining_impl_trait)]
|
||||||
|
async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
|
||||||
|
let prompt: Message = prompt.into();
|
||||||
|
let mut chat_history = vec![prompt.clone()];
|
||||||
|
let extracted = self
|
||||||
|
.chain_of_thought_extractor
|
||||||
|
.extract(prompt)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
tracing::error!("Extraction error: {:?}", e);
|
||||||
|
CompletionError::ProviderError("".into())
|
||||||
|
})?;
|
||||||
|
|
||||||
|
if extracted.steps.is_empty() {
|
||||||
|
return Ok("No reasoning steps provided.".into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut reasoning_prompt = String::new();
|
||||||
|
for (i, step) in extracted.steps.iter().enumerate() {
|
||||||
|
reasoning_prompt.push_str(&format!("Step {}: {}\n", i + 1, step));
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.executor
|
||||||
|
.prompt(reasoning_prompt.as_str())
|
||||||
|
.with_history(&mut chat_history)
|
||||||
|
.multi_turn(20)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> anyhow::Result<()> {
|
||||||
|
tracing_subscriber::fmt()
|
||||||
|
.with_max_level(tracing::Level::DEBUG)
|
||||||
|
.with_target(false)
|
||||||
|
.init();
|
||||||
|
|
||||||
|
// Create OpenAI client
|
||||||
|
let openai_client = anthropic::Client::from_env();
|
||||||
|
|
||||||
|
let agent = ReasoningAgent {
|
||||||
|
chain_of_thought_extractor: openai_client
|
||||||
|
.extractor(anthropic::CLAUDE_3_5_SONNET)
|
||||||
|
.preamble(CHAIN_OF_THOUGHT_PROMPT)
|
||||||
|
.build(),
|
||||||
|
|
||||||
|
executor: openai_client
|
||||||
|
.agent(anthropic::CLAUDE_3_5_SONNET)
|
||||||
|
.preamble(
|
||||||
|
"You are an assistant here to help the user select which tool is most appropriate to perform arithmetic operations.
|
||||||
|
Follow these instructions closely.
|
||||||
|
1. Consider the user's request carefully and identify the core elements of the request.
|
||||||
|
2. Select which tool among those made available to you is appropriate given the context.
|
||||||
|
3. This is very important: never perform the operation yourself.
|
||||||
|
4. When you think you've finished calling tools for the operation, present the final result from the series of tool calls you made.
|
||||||
|
"
|
||||||
|
)
|
||||||
|
.tool(Add)
|
||||||
|
.tool(Subtract)
|
||||||
|
.tool(Multiply)
|
||||||
|
.tool(Divide)
|
||||||
|
.build(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Prompt the agent and print the response
|
||||||
|
let result = agent.prompt("Calculate x for the equation: `20x + 23 = 400x / (1 - x)`").await?;
|
||||||
|
|
||||||
|
println!("\n\nReasoning Agent: {}", result);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Arguments shared by the four arithmetic tools (`Add`, `Subtract`,
/// `Multiply`, `Divide`): the two operands of a binary operation.
// NOTE(review): the tool definitions advertise both parameters as JSON
// "number" while they are deserialized into `i32` here; non-integer arguments
// would presumably be rejected at deserialization — confirm.
#[derive(Deserialize)]
struct OperationArgs {
    /// Left-hand operand.
    x: i32,
    /// Right-hand operand.
    y: i32,
}
|
||||||
|
|
||||||
|
/// Error type declared by every arithmetic tool in this example.
/// It carries no detail; its display text is always "Math error".
#[derive(Debug, thiserror::Error)]
#[error("Math error")]
struct MathError;
|
||||||
|
|
||||||
|
#[derive(Deserialize, Serialize)]
|
||||||
|
struct Add;
|
||||||
|
impl Tool for Add {
|
||||||
|
const NAME: &'static str = "add";
|
||||||
|
|
||||||
|
type Error = MathError;
|
||||||
|
type Args = OperationArgs;
|
||||||
|
type Output = i32;
|
||||||
|
|
||||||
|
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||||
|
serde_json::from_value(json!({
|
||||||
|
"name": "add",
|
||||||
|
"description": "Add x and y together",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"x": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The first number to add"
|
||||||
|
},
|
||||||
|
"y": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The second number to add"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.expect("Tool Definition")
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||||
|
let result = args.x + args.y;
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Serialize)]
|
||||||
|
struct Subtract;
|
||||||
|
impl Tool for Subtract {
|
||||||
|
const NAME: &'static str = "subtract";
|
||||||
|
|
||||||
|
type Error = MathError;
|
||||||
|
type Args = OperationArgs;
|
||||||
|
type Output = i32;
|
||||||
|
|
||||||
|
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||||
|
serde_json::from_value(json!({
|
||||||
|
"name": "subtract",
|
||||||
|
"description": "Subtract y from x (i.e.: x - y)",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"x": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The number to subtract from"
|
||||||
|
},
|
||||||
|
"y": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The number to subtract"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.expect("Tool Definition")
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||||
|
let result = args.x - args.y;
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Multiply;
|
||||||
|
impl Tool for Multiply {
|
||||||
|
const NAME: &'static str = "multiply";
|
||||||
|
|
||||||
|
type Error = MathError;
|
||||||
|
type Args = OperationArgs;
|
||||||
|
type Output = i32;
|
||||||
|
|
||||||
|
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||||
|
serde_json::from_value(json!({
|
||||||
|
"name": "multiply",
|
||||||
|
"description": "Compute the product of x and y (i.e.: x * y)",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"x": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The first factor in the product"
|
||||||
|
},
|
||||||
|
"y": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The second factor in the product"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.expect("Tool Definition")
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||||
|
let result = args.x * args.y;
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Divide;
|
||||||
|
impl Tool for Divide {
|
||||||
|
const NAME: &'static str = "divide";
|
||||||
|
|
||||||
|
type Error = MathError;
|
||||||
|
type Args = OperationArgs;
|
||||||
|
type Output = i32;
|
||||||
|
|
||||||
|
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||||
|
serde_json::from_value(json!({
|
||||||
|
"name": "divide",
|
||||||
|
"description": "Compute the Quotient of x and y (i.e.: x / y). Useful for ratios.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"x": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The Dividend of the division. The number being divided"
|
||||||
|
},
|
||||||
|
"y": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The Divisor of the division. The number by which the dividend is being divided"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.expect("Tool Definition")
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||||
|
let result = args.x / args.y;
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,684 +0,0 @@
|
||||||
//! This module contains the implementation of the [Agent] struct and its builder.
|
|
||||||
//!
|
|
||||||
//! The [Agent] struct represents an LLM agent, which combines an LLM model with a preamble (system prompt),
|
|
||||||
//! a set of context documents, and a set of tools. Note: both context documents and tools can be either
|
|
||||||
//! static (i.e.: they are always provided) or dynamic (i.e.: they are RAGged at prompt-time).
|
|
||||||
//!
|
|
||||||
//! The [Agent] struct is highly configurable, allowing the user to define anything from
|
|
||||||
//! a simple bot with a specific system prompt to a complex RAG system with a set of dynamic
|
|
||||||
//! context documents and tools.
|
|
||||||
//!
|
|
||||||
//! The [Agent] struct implements the [Completion] and [Prompt] traits, allowing it to be used for generating
|
|
||||||
//! completions responses and prompts. The [Agent] struct also implements the [Chat] trait, which allows it to
|
|
||||||
//! be used for generating chat completions.
|
|
||||||
//!
|
|
||||||
//! The [AgentBuilder] implements the builder pattern for creating instances of [Agent].
|
|
||||||
//! It allows configuring the model, preamble, context documents, tools, temperature, and additional parameters
|
|
||||||
//! before building the agent.
|
|
||||||
//!
|
|
||||||
//! # Example
|
|
||||||
//! ```rust
|
|
||||||
//! use rig::{
|
|
||||||
//! completion::{Chat, Completion, Prompt},
|
|
||||||
//! providers::openai,
|
|
||||||
//! };
|
|
||||||
//!
|
|
||||||
//! let openai = openai::Client::from_env();
|
|
||||||
//!
|
|
||||||
//! // Configure the agent
|
|
||||||
//! let agent = openai.agent("gpt-4o")
|
|
||||||
//! .preamble("System prompt")
|
|
||||||
//! .context("Context document 1")
|
|
||||||
//! .context("Context document 2")
|
|
||||||
//! .tool(tool1)
|
|
||||||
//! .tool(tool2)
|
|
||||||
//! .temperature(0.8)
|
|
||||||
//! .additional_params(json!({"foo": "bar"}))
|
|
||||||
//! .build();
|
|
||||||
//!
|
|
||||||
//! // Use the agent for completions and prompts
|
|
||||||
//! // Generate a chat completion response from a prompt and chat history
|
|
||||||
//! let chat_response = agent.chat("Prompt", chat_history)
|
|
||||||
//! .await
|
|
||||||
//! .expect("Failed to chat with Agent");
|
|
||||||
//!
|
|
||||||
//! // Generate a prompt completion response from a simple prompt
|
|
||||||
//! let chat_response = agent.prompt("Prompt")
|
|
||||||
//! .await
|
|
||||||
//! .expect("Failed to prompt the Agent");
|
|
||||||
//!
|
|
||||||
//! // Generate a completion request builder from a prompt and chat history. The builder
|
|
||||||
//! // will contain the agent's configuration (i.e.: preamble, context documents, tools,
|
|
||||||
//! // model parameters, etc.), but these can be overwritten.
|
|
||||||
//! let completion_req_builder = agent.completion("Prompt", chat_history)
|
|
||||||
//! .await
|
|
||||||
//! .expect("Failed to create completion request builder");
|
|
||||||
//!
|
|
||||||
//! let response = completion_req_builder
|
|
||||||
//! .temperature(0.9) // Overwrite the agent's temperature
|
|
||||||
//! .send()
|
|
||||||
//! .await
|
|
||||||
//! .expect("Failed to send completion request");
|
|
||||||
//! ```
|
|
||||||
//!
|
|
||||||
//! RAG Agent example
|
|
||||||
//! ```rust
|
|
||||||
//! use rig::{
|
|
||||||
//! completion::Prompt,
|
|
||||||
//! embeddings::EmbeddingsBuilder,
|
|
||||||
//! providers::openai,
|
|
||||||
//! vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
|
|
||||||
//! };
|
|
||||||
//!
|
|
||||||
//! // Initialize OpenAI client
|
|
||||||
//! let openai = openai::Client::from_env();
|
|
||||||
//!
|
|
||||||
//! // Initialize OpenAI embedding model
|
|
||||||
//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002);
|
|
||||||
//!
|
|
||||||
//! // Create vector store, compute embeddings and load them in the store
|
|
||||||
//! let mut vector_store = InMemoryVectorStore::default();
|
|
||||||
//!
|
|
||||||
//! let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
|
|
||||||
//! .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
|
|
||||||
//! .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is an ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
|
|
||||||
//! .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
|
|
||||||
//! .build()
|
|
||||||
//! .await
|
|
||||||
//! .expect("Failed to build embeddings");
|
|
||||||
//!
|
|
||||||
//! vector_store.add_documents(embeddings)
|
|
||||||
//! .await
|
|
||||||
//! .expect("Failed to add documents");
|
|
||||||
//!
|
|
||||||
//! // Create vector store index
|
|
||||||
//! let index = vector_store.index(embedding_model);
|
|
||||||
//!
|
|
||||||
//! let agent = openai.agent(openai::GPT_4O)
|
|
||||||
//! .preamble("
|
|
||||||
//! You are a dictionary assistant here to assist the user in understanding the meaning of words.
|
|
||||||
//! You will find additional non-standard word definitions that could be useful below.
|
|
||||||
//! ")
|
|
||||||
//! .dynamic_context(1, index)
|
|
||||||
//! .build();
|
|
||||||
//!
|
|
||||||
//! // Prompt the agent and print the response
|
|
||||||
//! let response = agent.prompt("What does \"glarb-glarb\" mean?").await
|
|
||||||
//! .expect("Failed to prompt the agent");
|
|
||||||
//! ```
|
|
||||||
use std::{collections::HashMap, future::IntoFuture};
|
|
||||||
|
|
||||||
use futures::{future::BoxFuture, stream, FutureExt, StreamExt, TryStreamExt};
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
completion::{
|
|
||||||
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
|
|
||||||
Message, Prompt, PromptError,
|
|
||||||
},
|
|
||||||
message::{AssistantContent, UserContent},
|
|
||||||
streaming::{
|
|
||||||
StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
|
|
||||||
StreamingResult,
|
|
||||||
},
|
|
||||||
tool::{Tool, ToolSet, ToolSetError},
|
|
||||||
vector_store::{VectorStoreError, VectorStoreIndexDyn},
|
|
||||||
OneOrMany,
|
|
||||||
};
|
|
||||||
|
|
||||||
#[cfg(feature = "mcp")]
|
|
||||||
use crate::tool::McpTool;
|
|
||||||
|
|
||||||
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt), a set of context documents and a set of tools.
///
/// Static context documents and static tools are attached to every request. Dynamic
/// context and dynamic tools are looked up in their vector store index at prompt time,
/// using the text of the prompt.
///
/// # Example
/// ```
/// use rig::{completion::Prompt, providers::openai};
///
/// let openai = openai::Client::from_env();
///
/// let comedian_agent = openai
///     .agent("gpt-4o")
///     .preamble("You are a comedian here to entertain the user using humour and jokes.")
///     .temperature(0.9)
///     .build();
///
/// let response = comedian_agent.prompt("Entertain me!")
///     .await
///     .expect("Failed to prompt the agent");
/// ```
pub struct Agent<M: CompletionModel> {
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    model: M,
    /// System prompt
    preamble: String,
    /// Context documents always available to the agent
    static_context: Vec<Document>,
    /// Tools that are always available to the agent (identified by their name)
    static_tools: Vec<String>,
    /// Temperature of the model
    temperature: Option<f64>,
    /// Maximum number of tokens for the completion
    max_tokens: Option<u64>,
    /// Additional parameters to be passed to the model
    additional_params: Option<serde_json::Value>,
    /// Vector store indexes used for dynamic context, each paired with the number of
    /// documents to sample from it per prompt
    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Vector store indexes used for dynamic tools, each paired with the number of
    /// tool ids to sample from it per prompt
    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Actual tool implementations (both static and dynamic tools are resolved here)
    pub tools: ToolSet,
}
|
|
||||||
|
|
||||||
impl<M: CompletionModel> Completion<M> for Agent<M> {
|
|
||||||
async fn completion(
|
|
||||||
&self,
|
|
||||||
prompt: impl Into<Message> + Send,
|
|
||||||
chat_history: Vec<Message>,
|
|
||||||
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
|
|
||||||
let prompt = prompt.into();
|
|
||||||
let rag_text = prompt.rag_text().clone();
|
|
||||||
|
|
||||||
let completion_request = self
|
|
||||||
.model
|
|
||||||
.completion_request(prompt)
|
|
||||||
.preamble(self.preamble.clone())
|
|
||||||
.messages(chat_history)
|
|
||||||
.temperature_opt(self.temperature)
|
|
||||||
.max_tokens_opt(self.max_tokens)
|
|
||||||
.additional_params_opt(self.additional_params.clone())
|
|
||||||
.documents(self.static_context.clone());
|
|
||||||
|
|
||||||
let agent = match &rag_text {
|
|
||||||
Some(text) => {
|
|
||||||
let dynamic_context = stream::iter(self.dynamic_context.iter())
|
|
||||||
.then(|(num_sample, index)| async {
|
|
||||||
Ok::<_, VectorStoreError>(
|
|
||||||
index
|
|
||||||
.top_n(text, *num_sample)
|
|
||||||
.await?
|
|
||||||
.into_iter()
|
|
||||||
.map(|(_, id, doc)| {
|
|
||||||
// Pretty print the document if possible for better readability
|
|
||||||
let text = serde_json::to_string_pretty(&doc)
|
|
||||||
.unwrap_or_else(|_| doc.to_string());
|
|
||||||
|
|
||||||
Document {
|
|
||||||
id,
|
|
||||||
text,
|
|
||||||
additional_props: HashMap::new(),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>(),
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.try_fold(vec![], |mut acc, docs| async {
|
|
||||||
acc.extend(docs);
|
|
||||||
Ok(acc)
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;
|
|
||||||
|
|
||||||
let dynamic_tools = stream::iter(self.dynamic_tools.iter())
|
|
||||||
.then(|(num_sample, index)| async {
|
|
||||||
Ok::<_, VectorStoreError>(
|
|
||||||
index
|
|
||||||
.top_n_ids(text, *num_sample)
|
|
||||||
.await?
|
|
||||||
.into_iter()
|
|
||||||
.map(|(_, id)| id)
|
|
||||||
.collect::<Vec<_>>(),
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.try_fold(vec![], |mut acc, docs| async {
|
|
||||||
for doc in docs {
|
|
||||||
if let Some(tool) = self.tools.get(&doc) {
|
|
||||||
acc.push(tool.definition(text.into()).await)
|
|
||||||
} else {
|
|
||||||
tracing::warn!("Tool implementation not found in toolset: {}", doc);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(acc)
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;
|
|
||||||
|
|
||||||
let static_tools = stream::iter(self.static_tools.iter())
|
|
||||||
.filter_map(|toolname| async move {
|
|
||||||
if let Some(tool) = self.tools.get(toolname) {
|
|
||||||
Some(tool.definition(text.into()).await)
|
|
||||||
} else {
|
|
||||||
tracing::warn!(
|
|
||||||
"Tool implementation not found in toolset: {}",
|
|
||||||
toolname
|
|
||||||
);
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.await;
|
|
||||||
|
|
||||||
completion_request
|
|
||||||
.documents(dynamic_context)
|
|
||||||
.tools([static_tools.clone(), dynamic_tools].concat())
|
|
||||||
}
|
|
||||||
None => {
|
|
||||||
let static_tools = stream::iter(self.static_tools.iter())
|
|
||||||
.filter_map(|toolname| async move {
|
|
||||||
if let Some(tool) = self.tools.get(toolname) {
|
|
||||||
// TODO: tool definitions should likely take an `Option<String>`
|
|
||||||
Some(tool.definition("".into()).await)
|
|
||||||
} else {
|
|
||||||
tracing::warn!(
|
|
||||||
"Tool implementation not found in toolset: {}",
|
|
||||||
toolname
|
|
||||||
);
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.await;
|
|
||||||
|
|
||||||
completion_request.tools(static_tools)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(agent)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Here, we need to ensure that usage of `.prompt` on agent uses these redefinitions on the opaque
|
|
||||||
// `Prompt` trait so that when `.prompt` is used at the call-site, it'll use the more specific
|
|
||||||
// `PromptRequest` implementation for `Agent`, making the builder's usage fluent.
|
|
||||||
//
|
|
||||||
// References:
|
|
||||||
// - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait)
|
|
||||||
|
|
||||||
// `refining_impl_trait` is allowed because this impl names the concrete
// `PromptRequest` return type rather than the trait's opaque one.
#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for Agent<M> {
    /// Starts a [`PromptRequest`] for `prompt` bound to this agent. Nothing is sent
    /// here; the request runs when the returned value is awaited.
    fn prompt(&self, prompt: impl Into<Message> + Send) -> PromptRequest<M> {
        PromptRequest::new(self, prompt)
    }
}
|
|
||||||
|
|
||||||
// Same refinement as the impl for `Agent<M>`, provided for shared references so
// that a `&Agent<M>` can be used wherever a `Prompt` implementor is expected.
#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for &Agent<M> {
    /// Starts a [`PromptRequest`] for `prompt` bound to the referenced agent.
    fn prompt(&self, prompt: impl Into<Message> + Send) -> PromptRequest<M> {
        // `self` is `&&Agent<M>` here; dereference once to get the `&Agent<M>`.
        PromptRequest::new(*self, prompt)
    }
}
|
|
||||||
|
|
||||||
#[allow(refining_impl_trait)]
|
|
||||||
impl<M: CompletionModel> Chat for Agent<M> {
|
|
||||||
async fn chat(
|
|
||||||
&self,
|
|
||||||
prompt: impl Into<Message> + Send,
|
|
||||||
chat_history: Vec<Message>,
|
|
||||||
) -> Result<String, PromptError> {
|
|
||||||
let mut cloned_history = chat_history.clone();
|
|
||||||
PromptRequest::new(self, prompt)
|
|
||||||
.with_history(&mut cloned_history)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A builder for creating an agent
///
/// Every setter consumes the builder and returns it, so calls can be chained;
/// `build` produces the final [`Agent`].
///
/// # Example
/// ```
/// use rig::{providers::openai, agent::AgentBuilder};
///
/// let openai = openai::Client::from_env();
///
/// let gpt4o = openai.completion_model("gpt-4o");
///
/// // Configure the agent
/// let agent = AgentBuilder::new(gpt4o)
///     .preamble("System prompt")
///     .context("Context document 1")
///     .context("Context document 2")
///     .tool(tool1)
///     .tool(tool2)
///     .temperature(0.8)
///     .additional_params(json!({"foo": "bar"}))
///     .build();
/// ```
pub struct AgentBuilder<M: CompletionModel> {
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    model: M,
    /// System prompt (`None` until set; `build` falls back to an empty string)
    preamble: Option<String>,
    /// Context documents always available to the agent
    static_context: Vec<Document>,
    /// Tools that are always available to the agent (by name)
    static_tools: Vec<String>,
    /// Additional parameters to be passed to the model
    additional_params: Option<serde_json::Value>,
    /// Maximum number of tokens for the completion
    max_tokens: Option<u64>,
    /// Vector store indexes used for dynamic context, each paired with its sample count
    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Vector store indexes used for dynamic tools, each paired with its sample count
    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Temperature of the model
    temperature: Option<f64>,
    /// Actual tool implementations
    tools: ToolSet,
}
|
|
||||||
|
|
||||||
impl<M: CompletionModel> AgentBuilder<M> {
|
|
||||||
pub fn new(model: M) -> Self {
|
|
||||||
Self {
|
|
||||||
model,
|
|
||||||
preamble: None,
|
|
||||||
static_context: vec![],
|
|
||||||
static_tools: vec![],
|
|
||||||
temperature: None,
|
|
||||||
max_tokens: None,
|
|
||||||
additional_params: None,
|
|
||||||
dynamic_context: vec![],
|
|
||||||
dynamic_tools: vec![],
|
|
||||||
tools: ToolSet::default(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set the system prompt
|
|
||||||
pub fn preamble(mut self, preamble: &str) -> Self {
|
|
||||||
self.preamble = Some(preamble.into());
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Append to the preamble of the agent
|
|
||||||
pub fn append_preamble(mut self, doc: &str) -> Self {
|
|
||||||
self.preamble = Some(format!(
|
|
||||||
"{}\n{}",
|
|
||||||
self.preamble.unwrap_or_else(|| "".into()),
|
|
||||||
doc
|
|
||||||
));
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Add a static context document to the agent
|
|
||||||
pub fn context(mut self, doc: &str) -> Self {
|
|
||||||
self.static_context.push(Document {
|
|
||||||
id: format!("static_doc_{}", self.static_context.len()),
|
|
||||||
text: doc.into(),
|
|
||||||
additional_props: HashMap::new(),
|
|
||||||
});
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Add a static tool to the agent
|
|
||||||
pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
|
|
||||||
let toolname = tool.name();
|
|
||||||
self.tools.add_tool(tool);
|
|
||||||
self.static_tools.push(toolname);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add an MCP tool to the agent
|
|
||||||
#[cfg(feature = "mcp")]
|
|
||||||
pub fn mcp_tool<T: mcp_core::transport::Transport>(
|
|
||||||
mut self,
|
|
||||||
tool: mcp_core::types::Tool,
|
|
||||||
client: mcp_core::client::Client<T>,
|
|
||||||
) -> Self {
|
|
||||||
let toolname = tool.name.clone();
|
|
||||||
self.tools.add_tool(McpTool::from_mcp_server(tool, client));
|
|
||||||
self.static_tools.push(toolname);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Add some dynamic context to the agent. On each prompt, `sample` documents from the
|
|
||||||
/// dynamic context will be inserted in the request.
|
|
||||||
pub fn dynamic_context(
|
|
||||||
mut self,
|
|
||||||
sample: usize,
|
|
||||||
dynamic_context: impl VectorStoreIndexDyn + 'static,
|
|
||||||
) -> Self {
|
|
||||||
self.dynamic_context
|
|
||||||
.push((sample, Box::new(dynamic_context)));
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
|
|
||||||
/// dynamic toolset will be inserted in the request.
|
|
||||||
pub fn dynamic_tools(
|
|
||||||
mut self,
|
|
||||||
sample: usize,
|
|
||||||
dynamic_tools: impl VectorStoreIndexDyn + 'static,
|
|
||||||
toolset: ToolSet,
|
|
||||||
) -> Self {
|
|
||||||
self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
|
|
||||||
self.tools.add_tools(toolset);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set the temperature of the model
|
|
||||||
pub fn temperature(mut self, temperature: f64) -> Self {
|
|
||||||
self.temperature = Some(temperature);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set the maximum number of tokens for the completion
|
|
||||||
pub fn max_tokens(mut self, max_tokens: u64) -> Self {
|
|
||||||
self.max_tokens = Some(max_tokens);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set additional parameters to be passed to the model
|
|
||||||
pub fn additional_params(mut self, params: serde_json::Value) -> Self {
|
|
||||||
self.additional_params = Some(params);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Build the agent
|
|
||||||
pub fn build(self) -> Agent<M> {
|
|
||||||
Agent {
|
|
||||||
model: self.model,
|
|
||||||
preamble: self.preamble.unwrap_or_default(),
|
|
||||||
static_context: self.static_context,
|
|
||||||
static_tools: self.static_tools,
|
|
||||||
temperature: self.temperature,
|
|
||||||
max_tokens: self.max_tokens,
|
|
||||||
additional_params: self.additional_params,
|
|
||||||
dynamic_context: self.dynamic_context,
|
|
||||||
dynamic_tools: self.dynamic_tools,
|
|
||||||
tools: self.tools,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A builder for creating prompt requests with customizable options.
///
/// A request borrows the agent it will run on (`'a`) and, optionally, a mutable chat
/// history (`'c`). It is created with `new`, adjusted with `with_history` /
/// `multi_turn`, and executed by awaiting it.
pub struct PromptRequest<'c, 'a, M: CompletionModel> {
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt. It is borrowed mutably
    /// because the request appends the exchanged messages to it.
    chat_history: Option<&'c mut Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_depth: usize,
    /// The agent to use for execution
    agent: &'a Agent<M>,
}
|
|
||||||
|
|
||||||
impl<'c: 'a, 'a, M: CompletionModel> PromptRequest<'c, 'a, M> {
|
|
||||||
/// Create a new PromptRequest with the given prompt and model
|
|
||||||
pub fn new(agent: &'c Agent<M>, prompt: impl Into<Message>) -> Self {
|
|
||||||
Self {
|
|
||||||
prompt: prompt.into(),
|
|
||||||
chat_history: None,
|
|
||||||
max_depth: 0,
|
|
||||||
agent,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'c, 'a, M: CompletionModel> PromptRequest<'c, 'a, M> {
|
|
||||||
/// Set the maximum depth for multi-turn conversations
|
|
||||||
pub fn multi_turn(self, depth: usize) -> PromptRequest<'c, 'a, M> {
|
|
||||||
PromptRequest {
|
|
||||||
prompt: self.prompt,
|
|
||||||
chat_history: self.chat_history,
|
|
||||||
max_depth: depth,
|
|
||||||
agent: self.agent,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Add chat history to the prompt request
|
|
||||||
pub fn with_history(self, history: &'c mut Vec<Message>) -> PromptRequest<'c, 'a, M> {
|
|
||||||
PromptRequest {
|
|
||||||
prompt: self.prompt,
|
|
||||||
chat_history: Some(history),
|
|
||||||
max_depth: self.max_depth,
|
|
||||||
agent: self.agent,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Makes a [`PromptRequest`] directly awaitable: awaiting it runs `send`.
///
/// Due to: RFC 2515, we have to use a `BoxFuture` for the `IntoFuture` implementation. In the
/// future, we should be able to use `impl Future<...>` directly via the associated type.
///
/// Ref: https://github.com/rust-lang/rust/issues/63063
impl<'c: 'a, 'a, M: CompletionModel + 'c> IntoFuture for PromptRequest<'c, 'a, M> {
    /// Text of the final response, or the error that stopped the request.
    type Output = Result<String, PromptError>;
    /// Boxed future, since the concrete future type of `send` cannot be named here.
    type IntoFuture = BoxFuture<'a, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        self.send().boxed()
    }
}
|
|
||||||
|
|
||||||
// Implementation for Agent
|
|
||||||
impl<M: CompletionModel> PromptRequest<'_, '_, M> {
|
|
||||||
async fn send(self) -> Result<String, PromptError> {
|
|
||||||
let agent = self.agent;
|
|
||||||
let mut prompt = self.prompt;
|
|
||||||
let chat_history = if let Some(history) = self.chat_history {
|
|
||||||
history
|
|
||||||
} else {
|
|
||||||
&mut Vec::new()
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut current_max_depth = 0;
|
|
||||||
while current_max_depth <= self.max_depth {
|
|
||||||
current_max_depth += 1;
|
|
||||||
|
|
||||||
if self.max_depth > 1 {
|
|
||||||
tracing::info!(
|
|
||||||
"Current conversation depth: {}/{}",
|
|
||||||
current_max_depth,
|
|
||||||
self.max_depth
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let resp = agent
|
|
||||||
.completion(prompt.clone(), chat_history.to_vec())
|
|
||||||
.await?
|
|
||||||
.send()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
chat_history.push(prompt);
|
|
||||||
|
|
||||||
let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
|
|
||||||
.choice
|
|
||||||
.iter()
|
|
||||||
.partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));
|
|
||||||
|
|
||||||
chat_history.push(Message::Assistant {
|
|
||||||
content: resp.choice.clone(),
|
|
||||||
});
|
|
||||||
|
|
||||||
if tool_calls.is_empty() {
|
|
||||||
let merged_texts = texts
|
|
||||||
.into_iter()
|
|
||||||
.filter_map(|content| {
|
|
||||||
if let AssistantContent::Text(text) = content {
|
|
||||||
Some(text.text.clone())
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.join("\n");
|
|
||||||
|
|
||||||
if self.max_depth > 1 {
|
|
||||||
tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
|
|
||||||
}
|
|
||||||
|
|
||||||
// If there are no tool calls, depth is not relevant, we can just return the merged text.
|
|
||||||
return Ok(merged_texts);
|
|
||||||
}
|
|
||||||
|
|
||||||
let tool_content = stream::iter(tool_calls)
|
|
||||||
.then(async |choice| {
|
|
||||||
if let AssistantContent::ToolCall(tool_call) = choice {
|
|
||||||
let output = agent
|
|
||||||
.tools
|
|
||||||
.call(
|
|
||||||
&tool_call.function.name,
|
|
||||||
tool_call.function.arguments.to_string(),
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
Ok(UserContent::tool_result(
|
|
||||||
tool_call.id.clone(),
|
|
||||||
OneOrMany::one(output.into()),
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
unreachable!(
|
|
||||||
"This should never happen as we already filtered for `ToolCall`"
|
|
||||||
)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect::<Vec<Result<UserContent, ToolSetError>>>()
|
|
||||||
.await
|
|
||||||
.into_iter()
|
|
||||||
.collect::<Result<Vec<_>, _>>()
|
|
||||||
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;
|
|
||||||
|
|
||||||
prompt = Message::User {
|
|
||||||
content: OneOrMany::many(tool_content).expect("There is atleast one tool call"),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we reach here, we never resolved the final tool call. We need to do ... something.
|
|
||||||
Err(PromptError::MaxDepthError {
|
|
||||||
max_depth: self.max_depth,
|
|
||||||
chat_history: chat_history.clone(),
|
|
||||||
prompt,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
|
|
||||||
async fn stream_completion(
|
|
||||||
&self,
|
|
||||||
prompt: impl Into<Message> + Send,
|
|
||||||
chat_history: Vec<Message>,
|
|
||||||
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
|
|
||||||
// Reuse the existing completion implementation to build the request
|
|
||||||
// This ensures streaming and non-streaming use the same request building logic
|
|
||||||
self.completion(prompt, chat_history).await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
|
|
||||||
async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
|
|
||||||
self.stream_chat(prompt, vec![]).await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
|
|
||||||
async fn stream_chat(
|
|
||||||
&self,
|
|
||||||
prompt: &str,
|
|
||||||
chat_history: Vec<Message>,
|
|
||||||
) -> Result<StreamingResult, CompletionError> {
|
|
||||||
self.stream_completion(prompt, chat_history)
|
|
||||||
.await?
|
|
||||||
.stream()
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,251 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use futures::{stream, StreamExt, TryStreamExt};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
completion::{
|
||||||
|
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
|
||||||
|
Message, Prompt, PromptError,
|
||||||
|
},
|
||||||
|
streaming::{
|
||||||
|
StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
|
||||||
|
StreamingResult,
|
||||||
|
},
|
||||||
|
tool::ToolSet,
|
||||||
|
vector_store::VectorStoreError,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::prompt_request;
|
||||||
|
use super::prompt_request::PromptRequest;
|
||||||
|
|
||||||
|
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
|
||||||
|
/// (i.e.: system prompt) and a static set of context documents and tools.
|
||||||
|
/// All context documents and tools are always provided to the agent when prompted.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
/// ```
|
||||||
|
/// use rig::{completion::Prompt, providers::openai};
|
||||||
|
///
|
||||||
|
/// let openai = openai::Client::from_env();
|
||||||
|
///
|
||||||
|
/// let comedian_agent = openai
|
||||||
|
/// .agent("gpt-4o")
|
||||||
|
/// .preamble("You are a comedian here to entertain the user using humour and jokes.")
|
||||||
|
/// .temperature(0.9)
|
||||||
|
/// .build();
|
||||||
|
///
|
||||||
|
/// let response = comedian_agent.prompt("Entertain me!")
|
||||||
|
/// .await
|
||||||
|
/// .expect("Failed to prompt the agent");
|
||||||
|
/// ```
|
||||||
|
pub struct Agent<M: CompletionModel> {
|
||||||
|
/// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
|
||||||
|
pub model: M,
|
||||||
|
/// System prompt
|
||||||
|
pub preamble: String,
|
||||||
|
/// Context documents always available to the agent
|
||||||
|
pub static_context: Vec<Document>,
|
||||||
|
/// Tools that are always available to the agent (identified by their name)
|
||||||
|
pub static_tools: Vec<String>,
|
||||||
|
/// Temperature of the model
|
||||||
|
pub temperature: Option<f64>,
|
||||||
|
/// Maximum number of tokens for the completion
|
||||||
|
pub max_tokens: Option<u64>,
|
||||||
|
/// Additional parameters to be passed to the model
|
||||||
|
pub additional_params: Option<serde_json::Value>,
|
||||||
|
/// List of vector store, with the sample number
|
||||||
|
pub dynamic_context: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
|
||||||
|
/// Dynamic tools
|
||||||
|
pub dynamic_tools: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
|
||||||
|
/// Actual tool implementations
|
||||||
|
pub tools: ToolSet,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<M: CompletionModel> Completion<M> for Agent<M> {
|
||||||
|
async fn completion(
|
||||||
|
&self,
|
||||||
|
prompt: impl Into<Message> + Send,
|
||||||
|
chat_history: Vec<Message>,
|
||||||
|
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
|
||||||
|
let prompt = prompt.into();
|
||||||
|
let rag_text = prompt.rag_text().clone();
|
||||||
|
|
||||||
|
let completion_request = self
|
||||||
|
.model
|
||||||
|
.completion_request(prompt)
|
||||||
|
.preamble(self.preamble.clone())
|
||||||
|
.messages(chat_history)
|
||||||
|
.temperature_opt(self.temperature)
|
||||||
|
.max_tokens_opt(self.max_tokens)
|
||||||
|
.additional_params_opt(self.additional_params.clone())
|
||||||
|
.documents(self.static_context.clone());
|
||||||
|
|
||||||
|
let agent = match &rag_text {
|
||||||
|
Some(text) => {
|
||||||
|
let dynamic_context = stream::iter(self.dynamic_context.iter())
|
||||||
|
.then(|(num_sample, index)| async {
|
||||||
|
Ok::<_, VectorStoreError>(
|
||||||
|
index
|
||||||
|
.top_n(text, *num_sample)
|
||||||
|
.await?
|
||||||
|
.into_iter()
|
||||||
|
.map(|(_, id, doc)| {
|
||||||
|
// Pretty print the document if possible for better readability
|
||||||
|
let text = serde_json::to_string_pretty(&doc)
|
||||||
|
.unwrap_or_else(|_| doc.to_string());
|
||||||
|
|
||||||
|
Document {
|
||||||
|
id,
|
||||||
|
text,
|
||||||
|
additional_props: HashMap::new(),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.try_fold(vec![], |mut acc, docs| async {
|
||||||
|
acc.extend(docs);
|
||||||
|
Ok(acc)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;
|
||||||
|
|
||||||
|
let dynamic_tools = stream::iter(self.dynamic_tools.iter())
|
||||||
|
.then(|(num_sample, index)| async {
|
||||||
|
Ok::<_, VectorStoreError>(
|
||||||
|
index
|
||||||
|
.top_n_ids(text, *num_sample)
|
||||||
|
.await?
|
||||||
|
.into_iter()
|
||||||
|
.map(|(_, id)| id)
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.try_fold(vec![], |mut acc, docs| async {
|
||||||
|
for doc in docs {
|
||||||
|
if let Some(tool) = self.tools.get(&doc) {
|
||||||
|
acc.push(tool.definition(text.into()).await)
|
||||||
|
} else {
|
||||||
|
tracing::warn!("Tool implementation not found in toolset: {}", doc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(acc)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;
|
||||||
|
|
||||||
|
let static_tools = stream::iter(self.static_tools.iter())
|
||||||
|
.filter_map(|toolname| async move {
|
||||||
|
if let Some(tool) = self.tools.get(toolname) {
|
||||||
|
Some(tool.definition(text.into()).await)
|
||||||
|
} else {
|
||||||
|
tracing::warn!(
|
||||||
|
"Tool implementation not found in toolset: {}",
|
||||||
|
toolname
|
||||||
|
);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
completion_request
|
||||||
|
.documents(dynamic_context)
|
||||||
|
.tools([static_tools.clone(), dynamic_tools].concat())
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
let static_tools = stream::iter(self.static_tools.iter())
|
||||||
|
.filter_map(|toolname| async move {
|
||||||
|
if let Some(tool) = self.tools.get(toolname) {
|
||||||
|
// TODO: tool definitions should likely take an `Option<String>`
|
||||||
|
Some(tool.definition("".into()).await)
|
||||||
|
} else {
|
||||||
|
tracing::warn!(
|
||||||
|
"Tool implementation not found in toolset: {}",
|
||||||
|
toolname
|
||||||
|
);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
completion_request.tools(static_tools)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(agent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Here, we need to ensure that usage of `.prompt` on agent uses these redefinitions on the opaque
|
||||||
|
// `Prompt` trait so that when `.prompt` is used at the call-site, it'll use the more specific
|
||||||
|
// `PromptRequest` implementation for `Agent`, making the builder's usage fluent.
|
||||||
|
//
|
||||||
|
// References:
|
||||||
|
// - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait)
|
||||||
|
|
||||||
|
#[allow(refining_impl_trait)]
|
||||||
|
impl<M: CompletionModel> Prompt for Agent<M> {
|
||||||
|
fn prompt(
|
||||||
|
&self,
|
||||||
|
prompt: impl Into<Message> + Send,
|
||||||
|
) -> PromptRequest<M, prompt_request::Simple> {
|
||||||
|
PromptRequest::new(self, prompt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(refining_impl_trait)]
|
||||||
|
impl<M: CompletionModel> Prompt for &Agent<M> {
|
||||||
|
fn prompt(
|
||||||
|
&self,
|
||||||
|
prompt: impl Into<Message> + Send,
|
||||||
|
) -> PromptRequest<M, prompt_request::Simple> {
|
||||||
|
PromptRequest::new(*self, prompt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(refining_impl_trait)]
|
||||||
|
impl<M: CompletionModel> Chat for Agent<M> {
|
||||||
|
async fn chat(
|
||||||
|
&self,
|
||||||
|
prompt: impl Into<Message> + Send,
|
||||||
|
chat_history: Vec<Message>,
|
||||||
|
) -> Result<String, PromptError> {
|
||||||
|
let mut cloned_history = chat_history.clone();
|
||||||
|
PromptRequest::new(self, prompt)
|
||||||
|
.with_history(&mut cloned_history)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
|
||||||
|
async fn stream_completion(
|
||||||
|
&self,
|
||||||
|
prompt: impl Into<Message> + Send,
|
||||||
|
chat_history: Vec<Message>,
|
||||||
|
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
|
||||||
|
// Reuse the existing completion implementation to build the request
|
||||||
|
// This ensures streaming and non-streaming use the same request building logic
|
||||||
|
self.completion(prompt, chat_history).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
|
||||||
|
async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
|
||||||
|
self.stream_chat(prompt, vec![]).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
|
||||||
|
async fn stream_chat(
|
||||||
|
&self,
|
||||||
|
prompt: &str,
|
||||||
|
chat_history: Vec<Message>,
|
||||||
|
) -> Result<StreamingResult, CompletionError> {
|
||||||
|
self.stream_completion(prompt, chat_history)
|
||||||
|
.await?
|
||||||
|
.stream()
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,179 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
completion::{CompletionModel, Document},
|
||||||
|
tool::{Tool, ToolSet},
|
||||||
|
vector_store::VectorStoreIndexDyn,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[cfg(feature = "mcp")]
|
||||||
|
use crate::tool::McpTool;
|
||||||
|
|
||||||
|
use super::Agent;
|
||||||
|
|
||||||
|
/// A builder for creating an agent
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
/// ```
|
||||||
|
/// use rig::{providers::openai, agent::AgentBuilder};
|
||||||
|
///
|
||||||
|
/// let openai = openai::Client::from_env();
|
||||||
|
///
|
||||||
|
/// let gpt4o = openai.completion_model("gpt-4o");
|
||||||
|
///
|
||||||
|
/// // Configure the agent
|
||||||
|
/// let agent = AgentBuilder::new(model)
|
||||||
|
/// .preamble("System prompt")
|
||||||
|
/// .context("Context document 1")
|
||||||
|
/// .context("Context document 2")
|
||||||
|
/// .tool(tool1)
|
||||||
|
/// .tool(tool2)
|
||||||
|
/// .temperature(0.8)
|
||||||
|
/// .additional_params(json!({"foo": "bar"}))
|
||||||
|
/// .build();
|
||||||
|
/// ```
|
||||||
|
pub struct AgentBuilder<M: CompletionModel> {
|
||||||
|
/// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
|
||||||
|
model: M,
|
||||||
|
/// System prompt
|
||||||
|
preamble: Option<String>,
|
||||||
|
/// Context documents always available to the agent
|
||||||
|
static_context: Vec<Document>,
|
||||||
|
/// Tools that are always available to the agent (by name)
|
||||||
|
static_tools: Vec<String>,
|
||||||
|
/// Additional parameters to be passed to the model
|
||||||
|
additional_params: Option<serde_json::Value>,
|
||||||
|
/// Maximum number of tokens for the completion
|
||||||
|
max_tokens: Option<u64>,
|
||||||
|
/// List of vector store, with the sample number
|
||||||
|
dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
|
||||||
|
/// Dynamic tools
|
||||||
|
dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
|
||||||
|
/// Temperature of the model
|
||||||
|
temperature: Option<f64>,
|
||||||
|
/// Actual tool implementations
|
||||||
|
tools: ToolSet,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<M: CompletionModel> AgentBuilder<M> {
|
||||||
|
pub fn new(model: M) -> Self {
|
||||||
|
Self {
|
||||||
|
model,
|
||||||
|
preamble: None,
|
||||||
|
static_context: vec![],
|
||||||
|
static_tools: vec![],
|
||||||
|
temperature: None,
|
||||||
|
max_tokens: None,
|
||||||
|
additional_params: None,
|
||||||
|
dynamic_context: vec![],
|
||||||
|
dynamic_tools: vec![],
|
||||||
|
tools: ToolSet::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the system prompt
|
||||||
|
pub fn preamble(mut self, preamble: &str) -> Self {
|
||||||
|
self.preamble = Some(preamble.into());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Append to the preamble of the agent
|
||||||
|
pub fn append_preamble(mut self, doc: &str) -> Self {
|
||||||
|
self.preamble = Some(format!(
|
||||||
|
"{}\n{}",
|
||||||
|
self.preamble.unwrap_or_else(|| "".into()),
|
||||||
|
doc
|
||||||
|
));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add a static context document to the agent
|
||||||
|
pub fn context(mut self, doc: &str) -> Self {
|
||||||
|
self.static_context.push(Document {
|
||||||
|
id: format!("static_doc_{}", self.static_context.len()),
|
||||||
|
text: doc.into(),
|
||||||
|
additional_props: HashMap::new(),
|
||||||
|
});
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add a static tool to the agent
|
||||||
|
pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
|
||||||
|
let toolname = tool.name();
|
||||||
|
self.tools.add_tool(tool);
|
||||||
|
self.static_tools.push(toolname);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add an MCP tool to the agent
|
||||||
|
#[cfg(feature = "mcp")]
|
||||||
|
pub fn mcp_tool<T: mcp_core::transport::Transport>(
|
||||||
|
mut self,
|
||||||
|
tool: mcp_core::types::Tool,
|
||||||
|
client: mcp_core::client::Client<T>,
|
||||||
|
) -> Self {
|
||||||
|
let toolname = tool.name.clone();
|
||||||
|
self.tools.add_tool(McpTool::from_mcp_server(tool, client));
|
||||||
|
self.static_tools.push(toolname);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add some dynamic context to the agent. On each prompt, `sample` documents from the
|
||||||
|
/// dynamic context will be inserted in the request.
|
||||||
|
pub fn dynamic_context(
|
||||||
|
mut self,
|
||||||
|
sample: usize,
|
||||||
|
dynamic_context: impl VectorStoreIndexDyn + 'static,
|
||||||
|
) -> Self {
|
||||||
|
self.dynamic_context
|
||||||
|
.push((sample, Box::new(dynamic_context)));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
|
||||||
|
/// dynamic toolset will be inserted in the request.
|
||||||
|
pub fn dynamic_tools(
|
||||||
|
mut self,
|
||||||
|
sample: usize,
|
||||||
|
dynamic_tools: impl VectorStoreIndexDyn + 'static,
|
||||||
|
toolset: ToolSet,
|
||||||
|
) -> Self {
|
||||||
|
self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
|
||||||
|
self.tools.add_tools(toolset);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the temperature of the model
|
||||||
|
pub fn temperature(mut self, temperature: f64) -> Self {
|
||||||
|
self.temperature = Some(temperature);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the maximum number of tokens for the completion
|
||||||
|
pub fn max_tokens(mut self, max_tokens: u64) -> Self {
|
||||||
|
self.max_tokens = Some(max_tokens);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set additional parameters to be passed to the model
|
||||||
|
pub fn additional_params(mut self, params: serde_json::Value) -> Self {
|
||||||
|
self.additional_params = Some(params);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build the agent
|
||||||
|
pub fn build(self) -> Agent<M> {
|
||||||
|
Agent {
|
||||||
|
model: self.model,
|
||||||
|
preamble: self.preamble.unwrap_or_default(),
|
||||||
|
static_context: self.static_context,
|
||||||
|
static_tools: self.static_tools,
|
||||||
|
temperature: self.temperature,
|
||||||
|
max_tokens: self.max_tokens,
|
||||||
|
additional_params: self.additional_params,
|
||||||
|
dynamic_context: self.dynamic_context,
|
||||||
|
dynamic_tools: self.dynamic_tools,
|
||||||
|
tools: self.tools,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,116 @@
|
||||||
|
//! This module contains the implementation of the [Agent] struct and its builder.
|
||||||
|
//!
|
||||||
|
//! The [Agent] struct represents an LLM agent, which combines an LLM model with a preamble (system prompt),
|
||||||
|
//! a set of context documents, and a set of tools. Note: both context documents and tools can be either
|
||||||
|
//! static (i.e.: they are always provided) or dynamic (i.e.: they are RAGged at prompt-time).
|
||||||
|
//!
|
||||||
|
//! The [Agent] struct is highly configurable, allowing the user to define anything from
|
||||||
|
//! a simple bot with a specific system prompt to a complex RAG system with a set of dynamic
|
||||||
|
//! context documents and tools.
|
||||||
|
//!
|
||||||
|
//! The [Agent] struct implements the [Completion] and [Prompt] traits, allowing it to be used for generating
|
||||||
|
//! completion responses and prompts. The [Agent] struct also implements the [Chat] trait, which allows it to
|
||||||
|
//! be used for generating chat completions.
|
||||||
|
//!
|
||||||
|
//! The [AgentBuilder] implements the builder pattern for creating instances of [Agent].
|
||||||
|
//! It allows configuring the model, preamble, context documents, tools, temperature, and additional parameters
|
||||||
|
//! before building the agent.
|
||||||
|
//!
|
||||||
|
//! # Example
|
||||||
|
//! ```rust
|
||||||
|
//! use rig::{
|
||||||
|
//! completion::{Chat, Completion, Prompt},
|
||||||
|
//! providers::openai,
|
||||||
|
//! };
|
||||||
|
//!
|
||||||
|
//! let openai = openai::Client::from_env();
|
||||||
|
//!
|
||||||
|
//! // Configure the agent
|
||||||
|
//! let agent = openai.agent("gpt-4o")
|
||||||
|
//! .preamble("System prompt")
|
||||||
|
//! .context("Context document 1")
|
||||||
|
//! .context("Context document 2")
|
||||||
|
//! .tool(tool1)
|
||||||
|
//! .tool(tool2)
|
||||||
|
//! .temperature(0.8)
|
||||||
|
//! .additional_params(json!({"foo": "bar"}))
|
||||||
|
//! .build();
|
||||||
|
//!
|
||||||
|
//! // Use the agent for completions and prompts
|
||||||
|
//! // Generate a chat completion response from a prompt and chat history
|
||||||
|
//! let chat_response = agent.chat("Prompt", chat_history)
|
||||||
|
//! .await
|
||||||
|
//! .expect("Failed to chat with Agent");
|
||||||
|
//!
|
||||||
|
//! // Generate a prompt completion response from a simple prompt
|
||||||
|
//! let chat_response = agent.prompt("Prompt")
|
||||||
|
//! .await
|
||||||
|
//! .expect("Failed to prompt the Agent");
|
||||||
|
//!
|
||||||
|
//! // Generate a completion request builder from a prompt and chat history. The builder
|
||||||
|
//! // will contain the agent's configuration (i.e.: preamble, context documents, tools,
|
||||||
|
//! // model parameters, etc.), but these can be overwritten.
|
||||||
|
//! let completion_req_builder = agent.completion("Prompt", chat_history)
|
||||||
|
//! .await
|
||||||
|
//! .expect("Failed to create completion request builder");
|
||||||
|
//!
|
||||||
|
//! let response = completion_req_builder
|
||||||
|
//! .temperature(0.9) // Overwrite the agent's temperature
|
||||||
|
//! .send()
|
||||||
|
//! .await
|
||||||
|
//! .expect("Failed to send completion request");
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! RAG Agent example
|
||||||
|
//! ```rust
|
||||||
|
//! use rig::{
|
||||||
|
//! completion::Prompt,
|
||||||
|
//! embeddings::EmbeddingsBuilder,
|
||||||
|
//! providers::openai,
|
||||||
|
//! vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
|
||||||
|
//! };
|
||||||
|
//!
|
||||||
|
//! // Initialize OpenAI client
|
||||||
|
//! let openai = openai::Client::from_env();
|
||||||
|
//!
|
||||||
|
//! // Initialize OpenAI embedding model
|
||||||
|
//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002);
|
||||||
|
//!
|
||||||
|
//! // Create vector store, compute embeddings and load them in the store
|
||||||
|
//! let mut vector_store = InMemoryVectorStore::default();
|
||||||
|
//!
|
||||||
|
//! let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
|
||||||
|
//! .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
|
||||||
|
//! .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
|
||||||
|
//! .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
|
||||||
|
//! .build()
|
||||||
|
//! .await
|
||||||
|
//! .expect("Failed to build embeddings");
|
||||||
|
//!
|
||||||
|
//! vector_store.add_documents(embeddings)
|
||||||
|
//! .await
|
||||||
|
//! .expect("Failed to add documents");
|
||||||
|
//!
|
||||||
|
//! // Create vector store index
|
||||||
|
//! let index = vector_store.index(embedding_model);
|
||||||
|
//!
|
||||||
|
//! let agent = openai.agent(openai::GPT_4O)
|
||||||
|
//! .preamble("
|
||||||
|
//! You are a dictionary assistant here to assist the user in understanding the meaning of words.
|
||||||
|
//! You will find additional non-standard word definitions that could be useful below.
|
||||||
|
//! ")
|
||||||
|
//! .dynamic_context(1, index)
|
||||||
|
//! .build();
|
||||||
|
//!
|
||||||
|
//! // Prompt the agent and print the response
|
||||||
|
//! let response = agent.prompt("What does \"glarb-glarb\" mean?").await
|
||||||
|
//! .expect("Failed to prompt the agent");
|
||||||
|
//! ```
|
||||||
|
|
||||||
|
mod agent;
|
||||||
|
mod builder;
|
||||||
|
mod prompt_request;
|
||||||
|
|
||||||
|
pub use agent::Agent;
|
||||||
|
pub use builder::AgentBuilder;
|
||||||
|
pub use prompt_request::PromptRequest;
|
|
@ -0,0 +1,232 @@
|
||||||
|
use std::{
|
||||||
|
future::{Future, IntoFuture},
|
||||||
|
marker::PhantomData,
|
||||||
|
};
|
||||||
|
|
||||||
|
use futures::{future::BoxFuture, stream, FutureExt, StreamExt};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
completion::{Completion, CompletionError, CompletionModel, Message, PromptError},
|
||||||
|
message::{AssistantContent, UserContent},
|
||||||
|
tool::ToolSetError,
|
||||||
|
OneOrMany,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::Agent;
|
||||||
|
|
||||||
|
/// Marker trait for the typestates a `PromptRequest` can be in.
pub trait State {}

/// Typestate marker: single round-trip request (the state a new request starts in).
pub struct Simple;

impl State for Simple {}

/// Typestate marker: multi-turn request, selected through `multi_turn`.
pub struct MultiTurn;

impl State for MultiTurn {}
|
||||||
|
|
||||||
|
pub trait SendPromptRequest<M: CompletionModel, T: State> {
|
||||||
|
fn send(self) -> impl Future<Output = Result<String, PromptError>> + Send;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A builder for creating prompt requests with customizable options.
|
||||||
|
/// Uses generics to track which options have been set during the build process.
|
||||||
|
pub struct PromptRequest<'c, 'a, M: CompletionModel, T: State> {
|
||||||
|
/// The prompt message to send to the model
|
||||||
|
prompt: Message,
|
||||||
|
/// Optional chat history to include with the prompt
|
||||||
|
/// Note: chat history needs to outlive the agent as it might be used with other agents
|
||||||
|
chat_history: Option<&'c mut Vec<Message>>,
|
||||||
|
/// Maximum depth for multi-turn conversations (0 means no multi-turn)
|
||||||
|
max_depth: usize,
|
||||||
|
/// The agent to use for execution
|
||||||
|
agent: &'a Agent<M>,
|
||||||
|
|
||||||
|
/// Typestate
|
||||||
|
_state: PhantomData<T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'c: 'a, 'a, M: CompletionModel> PromptRequest<'c, 'a, M, Simple> {
|
||||||
|
/// Create a new PromptRequest with the given prompt and model
|
||||||
|
pub fn new(agent: &'c Agent<M>, prompt: impl Into<Message>) -> Self {
|
||||||
|
Self {
|
||||||
|
prompt: prompt.into(),
|
||||||
|
chat_history: None,
|
||||||
|
max_depth: 0,
|
||||||
|
agent,
|
||||||
|
_state: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'c, 'a, M: CompletionModel> PromptRequest<'c, 'a, M, Simple> {
|
||||||
|
/// Set the maximum depth for multi-turn conversations
|
||||||
|
pub fn multi_turn(self, depth: usize) -> PromptRequest<'c, 'a, M, MultiTurn> {
|
||||||
|
PromptRequest {
|
||||||
|
prompt: self.prompt,
|
||||||
|
chat_history: self.chat_history,
|
||||||
|
max_depth: depth,
|
||||||
|
agent: self.agent,
|
||||||
|
_state: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add chat history to the prompt request
|
||||||
|
pub fn with_history(self, history: &'c mut Vec<Message>) -> PromptRequest<'c, 'a, M, Simple> {
|
||||||
|
PromptRequest {
|
||||||
|
prompt: self.prompt,
|
||||||
|
chat_history: Some(history),
|
||||||
|
max_depth: self.max_depth,
|
||||||
|
agent: self.agent,
|
||||||
|
_state: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Due to: RFC 2515, we have to use a `BoxFuture` for the `IntoFuture` implementation. In the
|
||||||
|
/// future, we should be able to use `impl Future<...>` directly via the associated type.
|
||||||
|
///
|
||||||
|
/// Ref: https://github.com/rust-lang/rust/issues/63063
|
||||||
|
impl<'c: 'a, 'a, M: CompletionModel, T: State + 'a> IntoFuture for PromptRequest<'c, 'a, M, T>
|
||||||
|
where
|
||||||
|
PromptRequest<'c, 'a, M, T>: SendPromptRequest<M, T>,
|
||||||
|
{
|
||||||
|
type Output = Result<String, PromptError>;
|
||||||
|
type IntoFuture = BoxFuture<'a, Self::Output>;
|
||||||
|
|
||||||
|
fn into_future(self) -> Self::IntoFuture {
|
||||||
|
self.send().boxed()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<M: CompletionModel, T: State> SendPromptRequest<M, T> for PromptRequest<'_, '_, M, MultiTurn> {
|
||||||
|
async fn send(self) -> Result<String, PromptError> {
|
||||||
|
let agent = self.agent;
|
||||||
|
let mut prompt = self.prompt;
|
||||||
|
let chat_history = if let Some(history) = self.chat_history {
|
||||||
|
history
|
||||||
|
} else {
|
||||||
|
&mut Vec::new()
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut current_max_depth = 0;
|
||||||
|
// We need to do atleast 2 loops for 1 roundtrip (user expects normal message)
|
||||||
|
while current_max_depth <= self.max_depth + 1 {
|
||||||
|
current_max_depth += 1;
|
||||||
|
|
||||||
|
if self.max_depth > 1 {
|
||||||
|
tracing::info!(
|
||||||
|
"Current conversation depth: {}/{}",
|
||||||
|
current_max_depth,
|
||||||
|
self.max_depth
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let resp = agent
|
||||||
|
.completion(prompt.clone(), chat_history.to_vec())
|
||||||
|
.await?
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
chat_history.push(prompt);
|
||||||
|
|
||||||
|
let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
|
||||||
|
.choice
|
||||||
|
.iter()
|
||||||
|
.partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));
|
||||||
|
|
||||||
|
chat_history.push(Message::Assistant {
|
||||||
|
content: resp.choice.clone(),
|
||||||
|
});
|
||||||
|
|
||||||
|
if tool_calls.is_empty() {
|
||||||
|
let merged_texts = texts
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|content| {
|
||||||
|
if let AssistantContent::Text(text) = content {
|
||||||
|
Some(text.text.clone())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("\n");
|
||||||
|
|
||||||
|
if self.max_depth > 1 {
|
||||||
|
tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there are no tool calls, depth is not relevant, we can just return the merged text.
|
||||||
|
return Ok(merged_texts);
|
||||||
|
}
|
||||||
|
|
||||||
|
let tool_content = stream::iter(tool_calls)
|
||||||
|
.then(async |choice| {
|
||||||
|
if let AssistantContent::ToolCall(tool_call) = choice {
|
||||||
|
let output = agent
|
||||||
|
.tools
|
||||||
|
.call(
|
||||||
|
&tool_call.function.name,
|
||||||
|
tool_call.function.arguments.to_string(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(UserContent::tool_result(
|
||||||
|
tool_call.id.clone(),
|
||||||
|
OneOrMany::one(output.into()),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
unreachable!(
|
||||||
|
"This should never happen as we already filtered for `ToolCall`"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Vec<Result<UserContent, ToolSetError>>>()
|
||||||
|
.await
|
||||||
|
.into_iter()
|
||||||
|
.collect::<Result<Vec<_>, _>>()
|
||||||
|
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;
|
||||||
|
|
||||||
|
prompt = Message::User {
|
||||||
|
content: OneOrMany::many(tool_content).expect("There is atleast one tool call"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we reach here, we never resolved the final tool call. We need to do ... something.
|
||||||
|
Err(PromptError::MaxDepthError {
|
||||||
|
max_depth: self.max_depth,
|
||||||
|
chat_history: chat_history.clone(),
|
||||||
|
prompt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<M: CompletionModel, T: State> SendPromptRequest<M, T> for PromptRequest<'_, '_, M, Simple> {
|
||||||
|
async fn send(self) -> Result<String, PromptError> {
|
||||||
|
let chat_history = if let Some(history) = self.chat_history {
|
||||||
|
history.clone()
|
||||||
|
} else {
|
||||||
|
Vec::new()
|
||||||
|
};
|
||||||
|
|
||||||
|
let resp = self
|
||||||
|
.agent
|
||||||
|
.completion(self.prompt, chat_history)
|
||||||
|
.await?
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
tracing::debug!(?resp.choice);
|
||||||
|
|
||||||
|
if resp.choice.len() > 1 {
|
||||||
|
tracing::warn!("Parallel tool calls are only available when using multi turn. Use `agent.prompt(...).multi_turn(depth).await`!");
|
||||||
|
}
|
||||||
|
|
||||||
|
match resp.choice.first() {
|
||||||
|
AssistantContent::Text(text) => Ok(text.text.clone()),
|
||||||
|
AssistantContent::ToolCall(tool_call) => Ok(self
|
||||||
|
.agent
|
||||||
|
.tools
|
||||||
|
.call(
|
||||||
|
&tool_call.function.name,
|
||||||
|
tool_call.function.arguments.to_string(),
|
||||||
|
)
|
||||||
|
.await?),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -477,6 +477,12 @@ impl From<String> for Text {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<&String> for Text {
|
||||||
|
fn from(text: &String) -> Self {
|
||||||
|
text.to_owned().into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl From<&str> for Text {
|
impl From<&str> for Text {
|
||||||
fn from(text: &str) -> Self {
|
fn from(text: &str) -> Self {
|
||||||
text.to_owned().into()
|
text.to_owned().into()
|
||||||
|
@ -507,6 +513,14 @@ impl From<&str> for Message {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<&String> for Message {
|
||||||
|
fn from(text: &String) -> Self {
|
||||||
|
Message::User {
|
||||||
|
content: OneOrMany::one(UserContent::Text(text.into())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl From<Text> for Message {
|
impl From<Text> for Message {
|
||||||
fn from(text: Text) -> Self {
|
fn from(text: Text) -> Self {
|
||||||
Message::User {
|
Message::User {
|
||||||
|
|
|
@ -37,6 +37,7 @@ use serde_json::json;
|
||||||
use crate::{
|
use crate::{
|
||||||
agent::{Agent, AgentBuilder},
|
agent::{Agent, AgentBuilder},
|
||||||
completion::{CompletionModel, Prompt, PromptError, ToolDefinition},
|
completion::{CompletionModel, Prompt, PromptError, ToolDefinition},
|
||||||
|
message::Message,
|
||||||
tool::Tool,
|
tool::Tool,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -62,7 +63,7 @@ impl<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, M: CompletionModel>
|
||||||
where
|
where
|
||||||
M: Sync,
|
M: Sync,
|
||||||
{
|
{
|
||||||
pub async fn extract(&self, text: &str) -> Result<T, ExtractionError> {
|
pub async fn extract(&self, text: impl Into<Message> + Send) -> Result<T, ExtractionError> {
|
||||||
let summary = self.agent.prompt(text).await?;
|
let summary = self.agent.prompt(text).await?;
|
||||||
|
|
||||||
if summary.is_empty() {
|
if summary.is_empty() {
|
||||||
|
|
|
@ -1,9 +1,7 @@
|
||||||
use std::future::IntoFuture;
|
use std::future::IntoFuture;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
completion::{self, CompletionModel},
|
completion::{self, CompletionModel}, extractor::{ExtractionError, Extractor}, message::Message, vector_store
|
||||||
extractor::{ExtractionError, Extractor},
|
|
||||||
vector_store,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::Op;
|
use super::Op;
|
||||||
|
@ -129,13 +127,13 @@ impl<M, Input, Output> Op for Extract<M, Input, Output>
|
||||||
where
|
where
|
||||||
M: CompletionModel,
|
M: CompletionModel,
|
||||||
Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
|
Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
|
||||||
Input: Into<String> + Send + Sync,
|
Input: Into<Message> + Send + Sync,
|
||||||
{
|
{
|
||||||
type Input = Input;
|
type Input = Input;
|
||||||
type Output = Result<Output, ExtractionError>;
|
type Output = Result<Output, ExtractionError>;
|
||||||
|
|
||||||
async fn call(&self, input: Self::Input) -> Self::Output {
|
async fn call(&self, input: Self::Input) -> Self::Output {
|
||||||
self.extractor.extract(&input.into()).await
|
self.extractor.extract(input).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue