refactor: blow up agent and introduce typestate

0xMochan 2025-04-10 18:04:08 -07:00
parent 4b38294264
commit 44aa8e9f13
9 changed files with 1051 additions and 690 deletions


@@ -0,0 +1,254 @@
use rig::{
agent::Agent,
completion::{CompletionError, CompletionModel, Prompt, PromptError, ToolDefinition},
extractor::Extractor,
message::Message,
providers::anthropic,
tool::Tool,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
const CHAIN_OF_THOUGHT_PROMPT: &str = "
You are an assistant that extracts reasoning steps from a given prompt.
Do not return text, only return a tool call.
";
#[derive(Deserialize, Serialize, Debug, Clone, JsonSchema)]
struct ChainOfThoughtSteps {
steps: Vec<String>,
}
struct ReasoningAgent<M: CompletionModel> {
chain_of_thought_extractor: Extractor<M, ChainOfThoughtSteps>,
executor: Agent<M>,
}
impl<M: CompletionModel> Prompt for ReasoningAgent<M> {
#[allow(refining_impl_trait)]
async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
let prompt: Message = prompt.into();
let mut chat_history = vec![prompt.clone()];
let extracted = self
.chain_of_thought_extractor
.extract(prompt)
.await
.map_err(|e| {
tracing::error!("Extraction error: {:?}", e);
CompletionError::ProviderError("chain-of-thought extraction failed".into())
})?;
if extracted.steps.is_empty() {
return Ok("No reasoning steps provided.".into());
}
let mut reasoning_prompt = String::new();
for (i, step) in extracted.steps.iter().enumerate() {
reasoning_prompt.push_str(&format!("Step {}: {}\n", i + 1, step));
}
let response = self
.executor
.prompt(reasoning_prompt.as_str())
.with_history(&mut chat_history)
.multi_turn(20)
.await?;
Ok(response)
}
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.with_target(false)
.init();
// Create Anthropic client
let anthropic_client = anthropic::Client::from_env();
let agent = ReasoningAgent {
chain_of_thought_extractor: anthropic_client
.extractor(anthropic::CLAUDE_3_5_SONNET)
.preamble(CHAIN_OF_THOUGHT_PROMPT)
.build(),
executor: anthropic_client
.agent(anthropic::CLAUDE_3_5_SONNET)
.preamble(
"You are an assistant here to help the user select which tool is most appropriate to perform arithmetic operations.
Follow these instructions closely.
1. Consider the user's request carefully and identify the core elements of the request.
2. Select which tool among those made available to you is appropriate given the context.
3. This is very important: never perform the operation yourself.
4. When you think you've finished calling tools for the operation, present the final result from the series of tool calls you made.
"
)
.tool(Add)
.tool(Subtract)
.tool(Multiply)
.tool(Divide)
.build(),
};
// Prompt the agent and print the response
let result = agent.prompt("Calculate x for the equation: `20x + 23 = 400x / (1 - x)`").await?;
println!("\n\nReasoning Agent: {}", result);
Ok(())
}
#[derive(Deserialize)]
struct OperationArgs {
x: i32,
y: i32,
}
#[derive(Debug, thiserror::Error)]
#[error("Math error")]
struct MathError;
#[derive(Deserialize, Serialize)]
struct Add;
impl Tool for Add {
const NAME: &'static str = "add";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
serde_json::from_value(json!({
"name": "add",
"description": "Add x and y together",
"parameters": {
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The first number to add"
},
"y": {
"type": "number",
"description": "The second number to add"
}
}
}
}))
.expect("Tool Definition")
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let result = args.x + args.y;
Ok(result)
}
}
#[derive(Deserialize, Serialize)]
struct Subtract;
impl Tool for Subtract {
const NAME: &'static str = "subtract";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
serde_json::from_value(json!({
"name": "subtract",
"description": "Subtract y from x (i.e.: x - y)",
"parameters": {
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The number to subtract from"
},
"y": {
"type": "number",
"description": "The number to subtract"
}
}
}
}))
.expect("Tool Definition")
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let result = args.x - args.y;
Ok(result)
}
}
struct Multiply;
impl Tool for Multiply {
const NAME: &'static str = "multiply";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
serde_json::from_value(json!({
"name": "multiply",
"description": "Compute the product of x and y (i.e.: x * y)",
"parameters": {
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The first factor in the product"
},
"y": {
"type": "number",
"description": "The second factor in the product"
}
}
}
}))
.expect("Tool Definition")
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let result = args.x * args.y;
Ok(result)
}
}
struct Divide;
impl Tool for Divide {
const NAME: &'static str = "divide";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
serde_json::from_value(json!({
"name": "divide",
"description": "Compute the Quotient of x and y (i.e.: x / y). Useful for ratios.",
"parameters": {
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The Dividend of the division. The number being divided"
},
"y": {
"type": "number",
"description": "The Divisor of the division. The number by which the dividend is being divided"
}
}
}
}))
.expect("Tool Definition")
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let result = args.x / args.y;
Ok(result)
}
}


@@ -1,684 +0,0 @@
//! This module contains the implementation of the [Agent] struct and its builder.
//!
//! The [Agent] struct represents an LLM agent, which combines an LLM model with a preamble (system prompt),
//! a set of context documents, and a set of tools. Note: both context documents and tools can be either
//! static (i.e.: they are always provided) or dynamic (i.e.: they are RAGged at prompt-time).
//!
//! The [Agent] struct is highly configurable, allowing the user to define anything from
//! a simple bot with a specific system prompt to a complex RAG system with a set of dynamic
//! context documents and tools.
//!
//! The [Agent] struct implements the [Completion] and [Prompt] traits, allowing it to be used for generating
//! completion responses and prompts. The [Agent] struct also implements the [Chat] trait, which allows it to
//! be used for generating chat completions.
//!
//! The [AgentBuilder] implements the builder pattern for creating instances of [Agent].
//! It allows configuring the model, preamble, context documents, tools, temperature, and additional parameters
//! before building the agent.
//!
//! # Example
//! ```rust
//! use rig::{
//! completion::{Chat, Completion, Prompt},
//! providers::openai,
//! };
//!
//! let openai = openai::Client::from_env();
//!
//! // Configure the agent
//! let agent = openai.agent("gpt-4o")
//! .preamble("System prompt")
//! .context("Context document 1")
//! .context("Context document 2")
//! .tool(tool1)
//! .tool(tool2)
//! .temperature(0.8)
//! .additional_params(json!({"foo": "bar"}))
//! .build();
//!
//! // Use the agent for completions and prompts
//! // Generate a chat completion response from a prompt and chat history
//! let chat_response = agent.chat("Prompt", chat_history)
//! .await
//! .expect("Failed to chat with Agent");
//!
//! // Generate a prompt completion response from a simple prompt
//! let chat_response = agent.prompt("Prompt")
//! .await
//! .expect("Failed to prompt the Agent");
//!
//! // Generate a completion request builder from a prompt and chat history. The builder
//! // will contain the agent's configuration (i.e.: preamble, context documents, tools,
//! // model parameters, etc.), but these can be overwritten.
//! let completion_req_builder = agent.completion("Prompt", chat_history)
//! .await
//! .expect("Failed to create completion request builder");
//!
//! let response = completion_req_builder
//! .temperature(0.9) // Overwrite the agent's temperature
//! .send()
//! .await
//! .expect("Failed to send completion request");
//! ```
//!
//! RAG Agent example
//! ```rust
//! use rig::{
//! completion::Prompt,
//! embeddings::EmbeddingsBuilder,
//! providers::openai,
//! vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
//! };
//!
//! // Initialize OpenAI client
//! let openai = openai::Client::from_env();
//!
//! // Initialize OpenAI embedding model
//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002);
//!
//! // Create vector store, compute embeddings and load them in the store
//! let mut vector_store = InMemoryVectorStore::default();
//!
//! let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
//! .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
//! .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
//! .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
//! .build()
//! .await
//! .expect("Failed to build embeddings");
//!
//! vector_store.add_documents(embeddings)
//! .await
//! .expect("Failed to add documents");
//!
//! // Create vector store index
//! let index = vector_store.index(embedding_model);
//!
//! let agent = openai.agent(openai::GPT_4O)
//! .preamble("
//! You are a dictionary assistant here to assist the user in understanding the meaning of words.
//! You will find additional non-standard word definitions that could be useful below.
//! ")
//! .dynamic_context(1, index)
//! .build();
//!
//! // Prompt the agent and print the response
//! let response = agent.prompt("What does \"glarb-glarb\" mean?").await
//! .expect("Failed to prompt the agent");
//! ```
use std::{collections::HashMap, future::IntoFuture};
use futures::{future::BoxFuture, stream, FutureExt, StreamExt, TryStreamExt};
use crate::{
completion::{
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
Message, Prompt, PromptError,
},
message::{AssistantContent, UserContent},
streaming::{
StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
StreamingResult,
},
tool::{Tool, ToolSet, ToolSetError},
vector_store::{VectorStoreError, VectorStoreIndexDyn},
OneOrMany,
};
#[cfg(feature = "mcp")]
use crate::tool::McpTool;
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
/// All context documents and tools are always provided to the agent when prompted.
///
/// # Example
/// ```
/// use rig::{completion::Prompt, providers::openai};
///
/// let openai = openai::Client::from_env();
///
/// let comedian_agent = openai
/// .agent("gpt-4o")
/// .preamble("You are a comedian here to entertain the user using humour and jokes.")
/// .temperature(0.9)
/// .build();
///
/// let response = comedian_agent.prompt("Entertain me!")
/// .await
/// .expect("Failed to prompt the agent");
/// ```
pub struct Agent<M: CompletionModel> {
/// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
model: M,
/// System prompt
preamble: String,
/// Context documents always available to the agent
static_context: Vec<Document>,
/// Tools that are always available to the agent (identified by their name)
static_tools: Vec<String>,
/// Temperature of the model
temperature: Option<f64>,
/// Maximum number of tokens for the completion
max_tokens: Option<u64>,
/// Additional parameters to be passed to the model
additional_params: Option<serde_json::Value>,
/// List of vector store, with the sample number
dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Dynamic tools
dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Actual tool implementations
pub tools: ToolSet,
}
impl<M: CompletionModel> Completion<M> for Agent<M> {
async fn completion(
&self,
prompt: impl Into<Message> + Send,
chat_history: Vec<Message>,
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
let prompt = prompt.into();
let rag_text = prompt.rag_text().clone();
let completion_request = self
.model
.completion_request(prompt)
.preamble(self.preamble.clone())
.messages(chat_history)
.temperature_opt(self.temperature)
.max_tokens_opt(self.max_tokens)
.additional_params_opt(self.additional_params.clone())
.documents(self.static_context.clone());
let agent = match &rag_text {
Some(text) => {
let dynamic_context = stream::iter(self.dynamic_context.iter())
.then(|(num_sample, index)| async {
Ok::<_, VectorStoreError>(
index
.top_n(text, *num_sample)
.await?
.into_iter()
.map(|(_, id, doc)| {
// Pretty print the document if possible for better readability
let text = serde_json::to_string_pretty(&doc)
.unwrap_or_else(|_| doc.to_string());
Document {
id,
text,
additional_props: HashMap::new(),
}
})
.collect::<Vec<_>>(),
)
})
.try_fold(vec![], |mut acc, docs| async {
acc.extend(docs);
Ok(acc)
})
.await
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;
let dynamic_tools = stream::iter(self.dynamic_tools.iter())
.then(|(num_sample, index)| async {
Ok::<_, VectorStoreError>(
index
.top_n_ids(text, *num_sample)
.await?
.into_iter()
.map(|(_, id)| id)
.collect::<Vec<_>>(),
)
})
.try_fold(vec![], |mut acc, docs| async {
for doc in docs {
if let Some(tool) = self.tools.get(&doc) {
acc.push(tool.definition(text.into()).await)
} else {
tracing::warn!("Tool implementation not found in toolset: {}", doc);
}
}
Ok(acc)
})
.await
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;
let static_tools = stream::iter(self.static_tools.iter())
.filter_map(|toolname| async move {
if let Some(tool) = self.tools.get(toolname) {
Some(tool.definition(text.into()).await)
} else {
tracing::warn!(
"Tool implementation not found in toolset: {}",
toolname
);
None
}
})
.collect::<Vec<_>>()
.await;
completion_request
.documents(dynamic_context)
.tools([static_tools.clone(), dynamic_tools].concat())
}
None => {
let static_tools = stream::iter(self.static_tools.iter())
.filter_map(|toolname| async move {
if let Some(tool) = self.tools.get(toolname) {
// TODO: tool definitions should likely take an `Option<String>`
Some(tool.definition("".into()).await)
} else {
tracing::warn!(
"Tool implementation not found in toolset: {}",
toolname
);
None
}
})
.collect::<Vec<_>>()
.await;
completion_request.tools(static_tools)
}
};
Ok(agent)
}
}
// Here, we need to ensure that usage of `.prompt` on agent uses these redefinitions on the opaque
// `Prompt` trait so that when `.prompt` is used at the call-site, it'll use the more specific
// `PromptRequest` implementation for `Agent`, making the builder's usage fluent.
//
// References:
// - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait)
#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for Agent<M> {
fn prompt(&self, prompt: impl Into<Message> + Send) -> PromptRequest<M> {
PromptRequest::new(self, prompt)
}
}
#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for &Agent<M> {
fn prompt(&self, prompt: impl Into<Message> + Send) -> PromptRequest<M> {
PromptRequest::new(*self, prompt)
}
}
#[allow(refining_impl_trait)]
impl<M: CompletionModel> Chat for Agent<M> {
async fn chat(
&self,
prompt: impl Into<Message> + Send,
chat_history: Vec<Message>,
) -> Result<String, PromptError> {
let mut cloned_history = chat_history.clone();
PromptRequest::new(self, prompt)
.with_history(&mut cloned_history)
.await
}
}
/// A builder for creating an agent
///
/// # Example
/// ```
/// use rig::{providers::openai, agent::AgentBuilder};
///
/// let openai = openai::Client::from_env();
///
/// let gpt4o = openai.completion_model("gpt-4o");
///
/// // Configure the agent
/// let agent = AgentBuilder::new(model)
/// .preamble("System prompt")
/// .context("Context document 1")
/// .context("Context document 2")
/// .tool(tool1)
/// .tool(tool2)
/// .temperature(0.8)
/// .additional_params(json!({"foo": "bar"}))
/// .build();
/// ```
pub struct AgentBuilder<M: CompletionModel> {
/// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
model: M,
/// System prompt
preamble: Option<String>,
/// Context documents always available to the agent
static_context: Vec<Document>,
/// Tools that are always available to the agent (by name)
static_tools: Vec<String>,
/// Additional parameters to be passed to the model
additional_params: Option<serde_json::Value>,
/// Maximum number of tokens for the completion
max_tokens: Option<u64>,
/// List of vector store, with the sample number
dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Dynamic tools
dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Temperature of the model
temperature: Option<f64>,
/// Actual tool implementations
tools: ToolSet,
}
impl<M: CompletionModel> AgentBuilder<M> {
pub fn new(model: M) -> Self {
Self {
model,
preamble: None,
static_context: vec![],
static_tools: vec![],
temperature: None,
max_tokens: None,
additional_params: None,
dynamic_context: vec![],
dynamic_tools: vec![],
tools: ToolSet::default(),
}
}
/// Set the system prompt
pub fn preamble(mut self, preamble: &str) -> Self {
self.preamble = Some(preamble.into());
self
}
/// Append to the preamble of the agent
pub fn append_preamble(mut self, doc: &str) -> Self {
self.preamble = Some(format!(
"{}\n{}",
self.preamble.unwrap_or_else(|| "".into()),
doc
));
self
}
/// Add a static context document to the agent
pub fn context(mut self, doc: &str) -> Self {
self.static_context.push(Document {
id: format!("static_doc_{}", self.static_context.len()),
text: doc.into(),
additional_props: HashMap::new(),
});
self
}
/// Add a static tool to the agent
pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
let toolname = tool.name();
self.tools.add_tool(tool);
self.static_tools.push(toolname);
self
}
// Add an MCP tool to the agent
#[cfg(feature = "mcp")]
pub fn mcp_tool<T: mcp_core::transport::Transport>(
mut self,
tool: mcp_core::types::Tool,
client: mcp_core::client::Client<T>,
) -> Self {
let toolname = tool.name.clone();
self.tools.add_tool(McpTool::from_mcp_server(tool, client));
self.static_tools.push(toolname);
self
}
/// Add some dynamic context to the agent. On each prompt, `sample` documents from the
/// dynamic context will be inserted in the request.
pub fn dynamic_context(
mut self,
sample: usize,
dynamic_context: impl VectorStoreIndexDyn + 'static,
) -> Self {
self.dynamic_context
.push((sample, Box::new(dynamic_context)));
self
}
/// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
/// dynamic toolset will be inserted in the request.
pub fn dynamic_tools(
mut self,
sample: usize,
dynamic_tools: impl VectorStoreIndexDyn + 'static,
toolset: ToolSet,
) -> Self {
self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
self.tools.add_tools(toolset);
self
}
/// Set the temperature of the model
pub fn temperature(mut self, temperature: f64) -> Self {
self.temperature = Some(temperature);
self
}
/// Set the maximum number of tokens for the completion
pub fn max_tokens(mut self, max_tokens: u64) -> Self {
self.max_tokens = Some(max_tokens);
self
}
/// Set additional parameters to be passed to the model
pub fn additional_params(mut self, params: serde_json::Value) -> Self {
self.additional_params = Some(params);
self
}
/// Build the agent
pub fn build(self) -> Agent<M> {
Agent {
model: self.model,
preamble: self.preamble.unwrap_or_default(),
static_context: self.static_context,
static_tools: self.static_tools,
temperature: self.temperature,
max_tokens: self.max_tokens,
additional_params: self.additional_params,
dynamic_context: self.dynamic_context,
dynamic_tools: self.dynamic_tools,
tools: self.tools,
}
}
}
/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
pub struct PromptRequest<'c, 'a, M: CompletionModel> {
/// The prompt message to send to the model
prompt: Message,
/// Optional chat history to include with the prompt
/// Note: chat history needs to outlive the agent as it might be used with other agents
chat_history: Option<&'c mut Vec<Message>>,
/// Maximum depth for multi-turn conversations (0 means no multi-turn)
max_depth: usize,
/// The agent to use for execution
agent: &'a Agent<M>,
}
impl<'c: 'a, 'a, M: CompletionModel> PromptRequest<'c, 'a, M> {
/// Create a new PromptRequest with the given prompt and model
pub fn new(agent: &'c Agent<M>, prompt: impl Into<Message>) -> Self {
Self {
prompt: prompt.into(),
chat_history: None,
max_depth: 0,
agent,
}
}
}
impl<'c, 'a, M: CompletionModel> PromptRequest<'c, 'a, M> {
/// Set the maximum depth for multi-turn conversations
pub fn multi_turn(self, depth: usize) -> PromptRequest<'c, 'a, M> {
PromptRequest {
prompt: self.prompt,
chat_history: self.chat_history,
max_depth: depth,
agent: self.agent,
}
}
/// Add chat history to the prompt request
pub fn with_history(self, history: &'c mut Vec<Message>) -> PromptRequest<'c, 'a, M> {
PromptRequest {
prompt: self.prompt,
chat_history: Some(history),
max_depth: self.max_depth,
agent: self.agent,
}
}
}
/// Until RFC 2515 (`impl Trait` in type aliases) is stabilized, we have to use a `BoxFuture` for the
/// `IntoFuture` implementation. In the future, we should be able to use `impl Future<...>` directly
/// via the associated type.
///
/// Ref: https://github.com/rust-lang/rust/issues/63063
impl<'c: 'a, 'a, M: CompletionModel + 'c> IntoFuture for PromptRequest<'c, 'a, M> {
type Output = Result<String, PromptError>;
type IntoFuture = BoxFuture<'a, Self::Output>;
fn into_future(self) -> Self::IntoFuture {
self.send().boxed()
}
}
// Implementation for Agent
impl<M: CompletionModel> PromptRequest<'_, '_, M> {
async fn send(self) -> Result<String, PromptError> {
let agent = self.agent;
let mut prompt = self.prompt;
let chat_history = if let Some(history) = self.chat_history {
history
} else {
&mut Vec::new()
};
let mut current_max_depth = 0;
while current_max_depth <= self.max_depth {
current_max_depth += 1;
if self.max_depth > 1 {
tracing::info!(
"Current conversation depth: {}/{}",
current_max_depth,
self.max_depth
);
}
let resp = agent
.completion(prompt.clone(), chat_history.to_vec())
.await?
.send()
.await?;
chat_history.push(prompt);
let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
.choice
.iter()
.partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));
chat_history.push(Message::Assistant {
content: resp.choice.clone(),
});
if tool_calls.is_empty() {
let merged_texts = texts
.into_iter()
.filter_map(|content| {
if let AssistantContent::Text(text) = content {
Some(text.text.clone())
} else {
None
}
})
.collect::<Vec<_>>()
.join("\n");
if self.max_depth > 1 {
tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
}
// If there are no tool calls, depth is not relevant, we can just return the merged text.
return Ok(merged_texts);
}
let tool_content = stream::iter(tool_calls)
.then(async |choice| {
if let AssistantContent::ToolCall(tool_call) = choice {
let output = agent
.tools
.call(
&tool_call.function.name,
tool_call.function.arguments.to_string(),
)
.await?;
Ok(UserContent::tool_result(
tool_call.id.clone(),
OneOrMany::one(output.into()),
))
} else {
unreachable!(
"This should never happen as we already filtered for `ToolCall`"
)
}
})
.collect::<Vec<Result<UserContent, ToolSetError>>>()
.await
.into_iter()
.collect::<Result<Vec<_>, _>>()
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;
prompt = Message::User {
content: OneOrMany::many(tool_content).expect("There is at least one tool call"),
};
}
// If we reach here, we never resolved the final tool call. We need to do ... something.
Err(PromptError::MaxDepthError {
max_depth: self.max_depth,
chat_history: chat_history.clone(),
prompt,
})
}
}
impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
async fn stream_completion(
&self,
prompt: impl Into<Message> + Send,
chat_history: Vec<Message>,
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
// Reuse the existing completion implementation to build the request
// This ensures streaming and non-streaming use the same request building logic
self.completion(prompt, chat_history).await
}
}
impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
self.stream_chat(prompt, vec![]).await
}
}
impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
async fn stream_chat(
&self,
prompt: &str,
chat_history: Vec<Message>,
) -> Result<StreamingResult, CompletionError> {
self.stream_completion(prompt, chat_history)
.await?
.stream()
.await
}
}

rig-core/src/agent/agent.rs Normal file

@@ -0,0 +1,251 @@
use std::collections::HashMap;
use futures::{stream, StreamExt, TryStreamExt};
use crate::{
completion::{
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
Message, Prompt, PromptError,
},
streaming::{
StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
StreamingResult,
},
tool::ToolSet,
vector_store::VectorStoreError,
};
use super::prompt_request;
use super::prompt_request::PromptRequest;
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
/// All context documents and tools are always provided to the agent when prompted.
///
/// # Example
/// ```
/// use rig::{completion::Prompt, providers::openai};
///
/// let openai = openai::Client::from_env();
///
/// let comedian_agent = openai
/// .agent("gpt-4o")
/// .preamble("You are a comedian here to entertain the user using humour and jokes.")
/// .temperature(0.9)
/// .build();
///
/// let response = comedian_agent.prompt("Entertain me!")
/// .await
/// .expect("Failed to prompt the agent");
/// ```
pub struct Agent<M: CompletionModel> {
/// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
pub model: M,
/// System prompt
pub preamble: String,
/// Context documents always available to the agent
pub static_context: Vec<Document>,
/// Tools that are always available to the agent (identified by their name)
pub static_tools: Vec<String>,
/// Temperature of the model
pub temperature: Option<f64>,
/// Maximum number of tokens for the completion
pub max_tokens: Option<u64>,
/// Additional parameters to be passed to the model
pub additional_params: Option<serde_json::Value>,
/// List of vector stores, each with the number of documents to sample per prompt
pub dynamic_context: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
/// Dynamic tools
pub dynamic_tools: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
/// Actual tool implementations
pub tools: ToolSet,
}
impl<M: CompletionModel> Completion<M> for Agent<M> {
async fn completion(
&self,
prompt: impl Into<Message> + Send,
chat_history: Vec<Message>,
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
let prompt = prompt.into();
let rag_text = prompt.rag_text().clone();
let completion_request = self
.model
.completion_request(prompt)
.preamble(self.preamble.clone())
.messages(chat_history)
.temperature_opt(self.temperature)
.max_tokens_opt(self.max_tokens)
.additional_params_opt(self.additional_params.clone())
.documents(self.static_context.clone());
let agent = match &rag_text {
Some(text) => {
let dynamic_context = stream::iter(self.dynamic_context.iter())
.then(|(num_sample, index)| async {
Ok::<_, VectorStoreError>(
index
.top_n(text, *num_sample)
.await?
.into_iter()
.map(|(_, id, doc)| {
// Pretty print the document if possible for better readability
let text = serde_json::to_string_pretty(&doc)
.unwrap_or_else(|_| doc.to_string());
Document {
id,
text,
additional_props: HashMap::new(),
}
})
.collect::<Vec<_>>(),
)
})
.try_fold(vec![], |mut acc, docs| async {
acc.extend(docs);
Ok(acc)
})
.await
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;
let dynamic_tools = stream::iter(self.dynamic_tools.iter())
.then(|(num_sample, index)| async {
Ok::<_, VectorStoreError>(
index
.top_n_ids(text, *num_sample)
.await?
.into_iter()
.map(|(_, id)| id)
.collect::<Vec<_>>(),
)
})
.try_fold(vec![], |mut acc, docs| async {
for doc in docs {
if let Some(tool) = self.tools.get(&doc) {
acc.push(tool.definition(text.into()).await)
} else {
tracing::warn!("Tool implementation not found in toolset: {}", doc);
}
}
Ok(acc)
})
.await
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;
let static_tools = stream::iter(self.static_tools.iter())
.filter_map(|toolname| async move {
if let Some(tool) = self.tools.get(toolname) {
Some(tool.definition(text.into()).await)
} else {
tracing::warn!(
"Tool implementation not found in toolset: {}",
toolname
);
None
}
})
.collect::<Vec<_>>()
.await;
completion_request
.documents(dynamic_context)
.tools([static_tools.clone(), dynamic_tools].concat())
}
None => {
let static_tools = stream::iter(self.static_tools.iter())
.filter_map(|toolname| async move {
if let Some(tool) = self.tools.get(toolname) {
// TODO: tool definitions should likely take an `Option<String>`
Some(tool.definition("".into()).await)
} else {
tracing::warn!(
"Tool implementation not found in toolset: {}",
toolname
);
None
}
})
.collect::<Vec<_>>()
.await;
completion_request.tools(static_tools)
}
};
Ok(agent)
}
}
// Here, we need to ensure that usage of `.prompt` on agent uses these redefinitions on the opaque
// `Prompt` trait so that when `.prompt` is used at the call-site, it'll use the more specific
// `PromptRequest` implementation for `Agent`, making the builder's usage fluent.
//
// References:
// - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait)
#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for Agent<M> {
fn prompt(
&self,
prompt: impl Into<Message> + Send,
) -> PromptRequest<M, prompt_request::Simple> {
PromptRequest::new(self, prompt)
}
}
#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for &Agent<M> {
fn prompt(
&self,
prompt: impl Into<Message> + Send,
) -> PromptRequest<M, prompt_request::Simple> {
PromptRequest::new(*self, prompt)
}
}
#[allow(refining_impl_trait)]
impl<M: CompletionModel> Chat for Agent<M> {
async fn chat(
&self,
prompt: impl Into<Message> + Send,
chat_history: Vec<Message>,
) -> Result<String, PromptError> {
let mut cloned_history = chat_history.clone();
PromptRequest::new(self, prompt)
.with_history(&mut cloned_history)
.await
}
}
impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
async fn stream_completion(
&self,
prompt: impl Into<Message> + Send,
chat_history: Vec<Message>,
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
// Reuse the existing completion implementation to build the request
// This ensures streaming and non-streaming use the same request building logic
self.completion(prompt, chat_history).await
}
}
impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
self.stream_chat(prompt, vec![]).await
}
}
impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
async fn stream_chat(
&self,
prompt: &str,
chat_history: Vec<Message>,
) -> Result<StreamingResult, CompletionError> {
self.stream_completion(prompt, chat_history)
.await?
.stream()
.await
}
}
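
A side effect of this refactor worth noting: the `Agent` fields are now `pub`, so an agent can be assembled directly from its parts instead of going through `AgentBuilder`. A minimal, hypothetical sketch (the `model` parameter stands in for any concrete `CompletionModel`; the preamble and temperature are illustrative):

use rig::{agent::Agent, completion::CompletionModel, tool::ToolSet};

// Hypothetical helper: construct an Agent field-by-field now that the fields are public.
fn bare_agent<M: CompletionModel>(model: M) -> Agent<M> {
    Agent {
        model,
        preamble: "You are a helpful assistant.".to_string(),
        static_context: vec![],
        static_tools: vec![],
        temperature: Some(0.7),
        max_tokens: None,
        additional_params: None,
        dynamic_context: vec![],
        dynamic_tools: vec![],
        tools: ToolSet::default(),
    }
}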


@@ -0,0 +1,179 @@
use std::collections::HashMap;
use crate::{
completion::{CompletionModel, Document},
tool::{Tool, ToolSet},
vector_store::VectorStoreIndexDyn,
};
#[cfg(feature = "mcp")]
use crate::tool::McpTool;
use super::Agent;
/// A builder for creating an agent
///
/// # Example
/// ```
/// use rig::{providers::openai, agent::AgentBuilder};
///
/// let openai = openai::Client::from_env();
///
/// let gpt4o = openai.completion_model("gpt-4o");
///
/// // Configure the agent
/// let agent = AgentBuilder::new(model)
/// .preamble("System prompt")
/// .context("Context document 1")
/// .context("Context document 2")
/// .tool(tool1)
/// .tool(tool2)
/// .temperature(0.8)
/// .additional_params(json!({"foo": "bar"}))
/// .build();
/// ```
pub struct AgentBuilder<M: CompletionModel> {
/// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
model: M,
/// System prompt
preamble: Option<String>,
/// Context documents always available to the agent
static_context: Vec<Document>,
/// Tools that are always available to the agent (by name)
static_tools: Vec<String>,
/// Additional parameters to be passed to the model
additional_params: Option<serde_json::Value>,
/// Maximum number of tokens for the completion
max_tokens: Option<u64>,
/// List of vector stores, each with the number of documents to sample per prompt
dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Dynamic tools
dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Temperature of the model
temperature: Option<f64>,
/// Actual tool implementations
tools: ToolSet,
}
impl<M: CompletionModel> AgentBuilder<M> {
pub fn new(model: M) -> Self {
Self {
model,
preamble: None,
static_context: vec![],
static_tools: vec![],
temperature: None,
max_tokens: None,
additional_params: None,
dynamic_context: vec![],
dynamic_tools: vec![],
tools: ToolSet::default(),
}
}
/// Set the system prompt
pub fn preamble(mut self, preamble: &str) -> Self {
self.preamble = Some(preamble.into());
self
}
/// Append to the preamble of the agent
pub fn append_preamble(mut self, doc: &str) -> Self {
self.preamble = Some(format!(
"{}\n{}",
self.preamble.unwrap_or_else(|| "".into()),
doc
));
self
}
/// Add a static context document to the agent
pub fn context(mut self, doc: &str) -> Self {
self.static_context.push(Document {
id: format!("static_doc_{}", self.static_context.len()),
text: doc.into(),
additional_props: HashMap::new(),
});
self
}
/// Add a static tool to the agent
pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
let toolname = tool.name();
self.tools.add_tool(tool);
self.static_tools.push(toolname);
self
}
/// Add an MCP tool to the agent
#[cfg(feature = "mcp")]
pub fn mcp_tool<T: mcp_core::transport::Transport>(
mut self,
tool: mcp_core::types::Tool,
client: mcp_core::client::Client<T>,
) -> Self {
let toolname = tool.name.clone();
self.tools.add_tool(McpTool::from_mcp_server(tool, client));
self.static_tools.push(toolname);
self
}
/// Add some dynamic context to the agent. On each prompt, `sample` documents from the
/// dynamic context will be inserted in the request.
pub fn dynamic_context(
mut self,
sample: usize,
dynamic_context: impl VectorStoreIndexDyn + 'static,
) -> Self {
self.dynamic_context
.push((sample, Box::new(dynamic_context)));
self
}
/// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
/// dynamic toolset will be inserted in the request.
pub fn dynamic_tools(
mut self,
sample: usize,
dynamic_tools: impl VectorStoreIndexDyn + 'static,
toolset: ToolSet,
) -> Self {
self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
self.tools.add_tools(toolset);
self
}
/// Set the temperature of the model
pub fn temperature(mut self, temperature: f64) -> Self {
self.temperature = Some(temperature);
self
}
/// Set the maximum number of tokens for the completion
pub fn max_tokens(mut self, max_tokens: u64) -> Self {
self.max_tokens = Some(max_tokens);
self
}
/// Set additional parameters to be passed to the model
pub fn additional_params(mut self, params: serde_json::Value) -> Self {
self.additional_params = Some(params);
self
}
/// Build the agent
pub fn build(self) -> Agent<M> {
Agent {
model: self.model,
preamble: self.preamble.unwrap_or_default(),
static_context: self.static_context,
static_tools: self.static_tools,
temperature: self.temperature,
max_tokens: self.max_tokens,
additional_params: self.additional_params,
dynamic_context: self.dynamic_context,
dynamic_tools: self.dynamic_tools,
tools: self.tools,
}
}
}

rig-core/src/agent/mod.rs Normal file

@@ -0,0 +1,116 @@
//! This module contains the implementation of the [Agent] struct and its builder.
//!
//! The [Agent] struct represents an LLM agent, which combines an LLM model with a preamble (system prompt),
//! a set of context documents, and a set of tools. Note: both context documents and tools can be either
//! static (i.e.: they are always provided) or dynamic (i.e.: they are RAGged at prompt-time).
//!
//! The [Agent] struct is highly configurable, allowing the user to define anything from
//! a simple bot with a specific system prompt to a complex RAG system with a set of dynamic
//! context documents and tools.
//!
//! The [Agent] struct implements the [Completion] and [Prompt] traits, allowing it to be used for generating
//! completion responses and prompts. The [Agent] struct also implements the [Chat] trait, which allows it to
//! be used for generating chat completions.
//!
//! The [AgentBuilder] implements the builder pattern for creating instances of [Agent].
//! It allows configuring the model, preamble, context documents, tools, temperature, and additional parameters
//! before building the agent.
//!
//! # Example
//! ```rust
//! use rig::{
//! completion::{Chat, Completion, Prompt},
//! providers::openai,
//! };
//!
//! let openai = openai::Client::from_env();
//!
//! // Configure the agent
//! let agent = openai.agent("gpt-4o")
//! .preamble("System prompt")
//! .context("Context document 1")
//! .context("Context document 2")
//! .tool(tool1)
//! .tool(tool2)
//! .temperature(0.8)
//! .additional_params(json!({"foo": "bar"}))
//! .build();
//!
//! // Use the agent for completions and prompts
//! // Generate a chat completion response from a prompt and chat history
//! let chat_response = agent.chat("Prompt", chat_history)
//! .await
//! .expect("Failed to chat with Agent");
//!
//! // Generate a prompt completion response from a simple prompt
//! let chat_response = agent.prompt("Prompt")
//! .await
//! .expect("Failed to prompt the Agent");
//!
//! // Generate a completion request builder from a prompt and chat history. The builder
//! // will contain the agent's configuration (i.e.: preamble, context documents, tools,
//! // model parameters, etc.), but these can be overwritten.
//! let completion_req_builder = agent.completion("Prompt", chat_history)
//! .await
//! .expect("Failed to create completion request builder");
//!
//! let response = completion_req_builder
//! .temperature(0.9) // Overwrite the agent's temperature
//! .send()
//! .await
//! .expect("Failed to send completion request");
//! ```
//!
//! RAG Agent example
//! ```rust
//! use rig::{
//! completion::Prompt,
//! embeddings::EmbeddingsBuilder,
//! providers::openai,
//! vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
//! };
//!
//! // Initialize OpenAI client
//! let openai = openai::Client::from_env();
//!
//! // Initialize OpenAI embedding model
//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002);
//!
//! // Create vector store, compute embeddings and load them in the store
//! let mut vector_store = InMemoryVectorStore::default();
//!
//! let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
//! .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
//! .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
//! .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
//! .build()
//! .await
//! .expect("Failed to build embeddings");
//!
//! vector_store.add_documents(embeddings)
//! .await
//! .expect("Failed to add documents");
//!
//! // Create vector store index
//! let index = vector_store.index(embedding_model);
//!
//! let agent = openai.agent(openai::GPT_4O)
//! .preamble("
//! You are a dictionary assistant here to assist the user in understanding the meaning of words.
//! You will find additional non-standard word definitions that could be useful below.
//! ")
//! .dynamic_context(1, index)
//! .build();
//!
//! // Prompt the agent and print the response
//! let response = agent.prompt("What does \"glarb-glarb\" mean?").await
//! .expect("Failed to prompt the agent");
//! ```
mod agent;
mod builder;
mod prompt_request;
pub use agent::Agent;
pub use builder::AgentBuilder;
pub use prompt_request::PromptRequest;


@@ -0,0 +1,232 @@
use std::{
future::{Future, IntoFuture},
marker::PhantomData,
};
use futures::{future::BoxFuture, stream, FutureExt, StreamExt};
use crate::{
completion::{Completion, CompletionError, CompletionModel, Message, PromptError},
message::{AssistantContent, UserContent},
tool::ToolSetError,
OneOrMany,
};
use super::Agent;
pub trait State {}
pub struct Simple;
pub struct MultiTurn;
impl State for Simple {}
impl State for MultiTurn {}
pub trait SendPromptRequest<M: CompletionModel, T: State> {
fn send(self) -> impl Future<Output = Result<String, PromptError>> + Send;
}
/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
pub struct PromptRequest<'c, 'a, M: CompletionModel, T: State> {
/// The prompt message to send to the model
prompt: Message,
/// Optional chat history to include with the prompt
/// Note: chat history needs to outlive the agent as it might be used with other agents
chat_history: Option<&'c mut Vec<Message>>,
/// Maximum depth for multi-turn conversations (0 means no multi-turn)
max_depth: usize,
/// The agent to use for execution
agent: &'a Agent<M>,
/// Typestate
_state: PhantomData<T>,
}
impl<'c: 'a, 'a, M: CompletionModel> PromptRequest<'c, 'a, M, Simple> {
/// Create a new PromptRequest with the given prompt and model
pub fn new(agent: &'c Agent<M>, prompt: impl Into<Message>) -> Self {
Self {
prompt: prompt.into(),
chat_history: None,
max_depth: 0,
agent,
_state: PhantomData,
}
}
}
impl<'c, 'a, M: CompletionModel> PromptRequest<'c, 'a, M, Simple> {
/// Set the maximum depth for multi-turn conversations
pub fn multi_turn(self, depth: usize) -> PromptRequest<'c, 'a, M, MultiTurn> {
PromptRequest {
prompt: self.prompt,
chat_history: self.chat_history,
max_depth: depth,
agent: self.agent,
_state: PhantomData,
}
}
/// Add chat history to the prompt request
pub fn with_history(self, history: &'c mut Vec<Message>) -> PromptRequest<'c, 'a, M, Simple> {
PromptRequest {
prompt: self.prompt,
chat_history: Some(history),
max_depth: self.max_depth,
agent: self.agent,
_state: PhantomData,
}
}
}
/// Until RFC 2515 (`impl Trait` in type aliases) is stabilized, we have to use a `BoxFuture` for the
/// `IntoFuture` implementation. In the future, we should be able to use `impl Future<...>` directly
/// via the associated type.
///
/// Ref: https://github.com/rust-lang/rust/issues/63063
impl<'c: 'a, 'a, M: CompletionModel, T: State + 'a> IntoFuture for PromptRequest<'c, 'a, M, T>
where
PromptRequest<'c, 'a, M, T>: SendPromptRequest<M, T>,
{
type Output = Result<String, PromptError>;
type IntoFuture = BoxFuture<'a, Self::Output>;
fn into_future(self) -> Self::IntoFuture {
self.send().boxed()
}
}
impl<M: CompletionModel, T: State> SendPromptRequest<M, T> for PromptRequest<'_, '_, M, MultiTurn> {
async fn send(self) -> Result<String, PromptError> {
let agent = self.agent;
let mut prompt = self.prompt;
let chat_history = if let Some(history) = self.chat_history {
history
} else {
&mut Vec::new()
};
let mut current_max_depth = 0;
// We need at least 2 loops for 1 round trip (the user expects a normal message)
while current_max_depth <= self.max_depth + 1 {
current_max_depth += 1;
if self.max_depth > 1 {
tracing::info!(
"Current conversation depth: {}/{}",
current_max_depth,
self.max_depth
);
}
let resp = agent
.completion(prompt.clone(), chat_history.to_vec())
.await?
.send()
.await?;
chat_history.push(prompt);
let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
.choice
.iter()
.partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));
chat_history.push(Message::Assistant {
content: resp.choice.clone(),
});
if tool_calls.is_empty() {
let merged_texts = texts
.into_iter()
.filter_map(|content| {
if let AssistantContent::Text(text) = content {
Some(text.text.clone())
} else {
None
}
})
.collect::<Vec<_>>()
.join("\n");
if self.max_depth > 1 {
tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
}
// If there are no tool calls, depth is not relevant, we can just return the merged text.
return Ok(merged_texts);
}
let tool_content = stream::iter(tool_calls)
.then(async |choice| {
if let AssistantContent::ToolCall(tool_call) = choice {
let output = agent
.tools
.call(
&tool_call.function.name,
tool_call.function.arguments.to_string(),
)
.await?;
Ok(UserContent::tool_result(
tool_call.id.clone(),
OneOrMany::one(output.into()),
))
} else {
unreachable!(
"This should never happen as we already filtered for `ToolCall`"
)
}
})
.collect::<Vec<Result<UserContent, ToolSetError>>>()
.await
.into_iter()
.collect::<Result<Vec<_>, _>>()
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;
prompt = Message::User {
content: OneOrMany::many(tool_content).expect("There is at least one tool call"),
};
}
// If we reach here, we never resolved the final tool call. We need to do ... something.
Err(PromptError::MaxDepthError {
max_depth: self.max_depth,
chat_history: chat_history.clone(),
prompt,
})
}
}
impl<M: CompletionModel, T: State> SendPromptRequest<M, T> for PromptRequest<'_, '_, M, Simple> {
async fn send(self) -> Result<String, PromptError> {
let chat_history = if let Some(history) = self.chat_history {
history.clone()
} else {
Vec::new()
};
let resp = self
.agent
.completion(self.prompt, chat_history)
.await?
.send()
.await?;
tracing::debug!(?resp.choice);
if resp.choice.len() > 1 {
tracing::warn!("Parallel tool calls are only available when using multi turn. Use `agent.prompt(...).multi_turn(depth).await`!");
}
match resp.choice.first() {
AssistantContent::Text(text) => Ok(text.text.clone()),
AssistantContent::ToolCall(tool_call) => Ok(self
.agent
.tools
.call(
&tool_call.function.name,
tool_call.function.arguments.to_string(),
)
.await?),
}
}
}
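
For reference, a minimal sketch of how the typestate reads at the call-site (the client, model, and prompts are illustrative, mirroring the example file at the top of this commit): `.prompt(...)` produces a request in the `Simple` state, `.with_history(...)` keeps it there, and `.multi_turn(n)` moves it into `MultiTurn` before it is awaited.

use rig::{completion::Prompt, providers::anthropic};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let agent = anthropic::Client::from_env()
        .agent(anthropic::CLAUDE_3_5_SONNET)
        .preamble("You are a calculator assistant.")
        .build();

    // Simple state: awaiting resolves a single completion round.
    let answer = agent.prompt("What is 2 + 2?").await?;
    println!("{answer}");

    // MultiTurn state: `.multi_turn(5)` consumes the Simple request and returns a
    // MultiTurn one that loops over tool calls up to the given depth, recording
    // each round in the supplied chat history.
    let mut history = Vec::new();
    let answer = agent
        .prompt("Add 2 and 3, then multiply the result by 4")
        .with_history(&mut history)
        .multi_turn(5)
        .await?;
    println!("{answer}");

    Ok(())
}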


@@ -477,6 +477,12 @@ impl From<String> for Text {
}
}
impl From<&String> for Text {
fn from(text: &String) -> Self {
text.to_owned().into()
}
}
impl From<&str> for Text {
fn from(text: &str) -> Self {
text.to_owned().into()
@@ -507,6 +513,14 @@ impl From<&str> for Message {
}
}
impl From<&String> for Message {
fn from(text: &String) -> Self {
Message::User {
content: OneOrMany::one(UserContent::Text(text.into())),
}
}
}
impl From<Text> for Message {
fn from(text: Text) -> Self {
Message::User {


@@ -37,6 +37,7 @@ use serde_json::json;
use crate::{
agent::{Agent, AgentBuilder},
completion::{CompletionModel, Prompt, PromptError, ToolDefinition},
message::Message,
tool::Tool,
};
@@ -62,7 +63,7 @@ impl<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, M: CompletionModel>
where
M: Sync,
{
pub async fn extract(&self, text: &str) -> Result<T, ExtractionError> {
pub async fn extract(&self, text: impl Into<Message> + Send) -> Result<T, ExtractionError> {
let summary = self.agent.prompt(text).await?;
if summary.is_empty() {


@@ -1,9 +1,7 @@
use std::future::IntoFuture;
use crate::{
completion::{self, CompletionModel},
extractor::{ExtractionError, Extractor},
vector_store,
completion::{self, CompletionModel}, extractor::{ExtractionError, Extractor}, message::Message, vector_store
};
use super::Op;
@@ -129,13 +127,13 @@ impl<M, Input, Output> Op for Extract<M, Input, Output>
where
M: CompletionModel,
Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
Input: Into<String> + Send + Sync,
Input: Into<Message> + Send + Sync,
{
type Input = Input;
type Output = Result<Output, ExtractionError>;
async fn call(&self, input: Self::Input) -> Self::Output {
self.extractor.extract(&input.into()).await
self.extractor.extract(input).await
}
}
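
Combined with the `From<&String>` conversions added to `message.rs` above, the widened `extract` signature means extractor call-sites can now pass a `&str`, a `&String`, or a pre-built `Message`. A small hypothetical sketch (the client, model, and target schema are illustrative only):

use rig::{message::Message, providers::anthropic};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

#[derive(Deserialize, Serialize, JsonSchema)]
struct Sentiment {
    label: String,
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let extractor = anthropic::Client::from_env()
        .extractor::<Sentiment>(anthropic::CLAUDE_3_5_SONNET)
        .build();

    // &str still works, since &str: Into<Message>.
    let a = extractor.extract("I love this library!").await?;

    // &String now converts too, via the new From<&String> impl.
    let text = String::from("This release is fantastic.");
    let b = extractor.extract(&text).await?;

    // ...and a pre-built Message passes through unchanged.
    let c = extractor.extract(Message::from("Not a fan of the API churn.")).await?;

    println!("{} {} {}", a.label, b.label, c.label);
    Ok(())
}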