From 44aa8e9f13e65e3f629e2cb5744b32226ecd2a3e Mon Sep 17 00:00:00 2001 From: 0xMochan Date: Thu, 10 Apr 2025 18:04:08 -0700 Subject: [PATCH] refactor: blow up agent and introduce typestate --- rig-core/examples/reasoning_loop.rs | 254 ++++++++++ rig-core/src/agent.rs | 684 --------------------------- rig-core/src/agent/agent.rs | 251 ++++++++++ rig-core/src/agent/builder.rs | 179 +++++++ rig-core/src/agent/mod.rs | 116 +++++ rig-core/src/agent/prompt_request.rs | 232 +++++++++ rig-core/src/completion/message.rs | 14 + rig-core/src/extractor.rs | 3 +- rig-core/src/pipeline/agent_ops.rs | 8 +- 9 files changed, 1051 insertions(+), 690 deletions(-) create mode 100644 rig-core/examples/reasoning_loop.rs delete mode 100644 rig-core/src/agent.rs create mode 100644 rig-core/src/agent/agent.rs create mode 100644 rig-core/src/agent/builder.rs create mode 100644 rig-core/src/agent/mod.rs create mode 100644 rig-core/src/agent/prompt_request.rs diff --git a/rig-core/examples/reasoning_loop.rs b/rig-core/examples/reasoning_loop.rs new file mode 100644 index 0000000..8ad7397 --- /dev/null +++ b/rig-core/examples/reasoning_loop.rs @@ -0,0 +1,254 @@ +use rig::{ + agent::Agent, + completion::{CompletionError, CompletionModel, Prompt, PromptError, ToolDefinition}, + extractor::Extractor, + message::Message, + providers::anthropic, + tool::Tool, +}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use serde_json::json; + +const CHAIN_OF_THOUGHT_PROMPT: &str = " +You are an assistant that extracts reasoning steps from a given prompt. +Do not return text, only return a tool call. +"; + +#[derive(Deserialize, Serialize, Debug, Clone, JsonSchema)] +struct ChainOfThoughtSteps { + steps: Vec, +} + +struct ReasoningAgent { + chain_of_thought_extractor: Extractor, + executor: Agent, +} + +impl Prompt for ReasoningAgent { + #[allow(refining_impl_trait)] + async fn prompt(&self, prompt: impl Into + Send) -> Result { + let prompt: Message = prompt.into(); + let mut chat_history = vec![prompt.clone()]; + let extracted = self + .chain_of_thought_extractor + .extract(prompt) + .await + .map_err(|e| { + tracing::error!("Extraction error: {:?}", e); + CompletionError::ProviderError("".into()) + })?; + + if extracted.steps.is_empty() { + return Ok("No reasoning steps provided.".into()); + } + + let mut reasoning_prompt = String::new(); + for (i, step) in extracted.steps.iter().enumerate() { + reasoning_prompt.push_str(&format!("Step {}: {}\n", i + 1, step)); + } + + let response = self + .executor + .prompt(reasoning_prompt.as_str()) + .with_history(&mut chat_history) + .multi_turn(20) + .await?; + + Ok(response) + } +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::DEBUG) + .with_target(false) + .init(); + + // Create OpenAI client + let openai_client = anthropic::Client::from_env(); + + let agent = ReasoningAgent { + chain_of_thought_extractor: openai_client + .extractor(anthropic::CLAUDE_3_5_SONNET) + .preamble(CHAIN_OF_THOUGHT_PROMPT) + .build(), + + executor: openai_client + .agent(anthropic::CLAUDE_3_5_SONNET) + .preamble( + "You are an assistant here to help the user select which tool is most appropriate to perform arithmetic operations. + Follow these instructions closely. + 1. Consider the user's request carefully and identify the core elements of the request. + 2. Select which tool among those made available to you is appropriate given the context. + 3. This is very important: never perform the operation yourself. + 4. 
When you think you've finished calling tools for the operation, present the final result from the series of tool calls you made. + " + ) + .tool(Add) + .tool(Subtract) + .tool(Multiply) + .tool(Divide) + .build(), + }; + + // Prompt the agent and print the response + let result = agent.prompt("Calculate x for the equation: `20x + 23 = 400x / (1 - x)`").await?; + + println!("\n\nReasoning Agent: {}", result); + + Ok(()) +} + +#[derive(Deserialize)] +struct OperationArgs { + x: i32, + y: i32, +} + +#[derive(Debug, thiserror::Error)] +#[error("Math error")] +struct MathError; + +#[derive(Deserialize, Serialize)] +struct Add; +impl Tool for Add { + const NAME: &'static str = "add"; + + type Error = MathError; + type Args = OperationArgs; + type Output = i32; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + serde_json::from_value(json!({ + "name": "add", + "description": "Add x and y together", + "parameters": { + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "The first number to add" + }, + "y": { + "type": "number", + "description": "The second number to add" + } + } + } + })) + .expect("Tool Definition") + } + + async fn call(&self, args: Self::Args) -> Result { + let result = args.x + args.y; + Ok(result) + } +} + +#[derive(Deserialize, Serialize)] +struct Subtract; +impl Tool for Subtract { + const NAME: &'static str = "subtract"; + + type Error = MathError; + type Args = OperationArgs; + type Output = i32; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + serde_json::from_value(json!({ + "name": "subtract", + "description": "Subtract y from x (i.e.: x - y)", + "parameters": { + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "The number to subtract from" + }, + "y": { + "type": "number", + "description": "The number to subtract" + } + } + } + })) + .expect("Tool Definition") + } + + async fn call(&self, args: Self::Args) -> Result { + let result = args.x - args.y; + Ok(result) + } +} + +struct Multiply; +impl Tool for Multiply { + const NAME: &'static str = "multiply"; + + type Error = MathError; + type Args = OperationArgs; + type Output = i32; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + serde_json::from_value(json!({ + "name": "multiply", + "description": "Compute the product of x and y (i.e.: x * y)", + "parameters": { + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "The first factor in the product" + }, + "y": { + "type": "number", + "description": "The second factor in the product" + } + } + } + })) + .expect("Tool Definition") + } + + async fn call(&self, args: Self::Args) -> Result { + let result = args.x * args.y; + Ok(result) + } +} + +struct Divide; +impl Tool for Divide { + const NAME: &'static str = "divide"; + + type Error = MathError; + type Args = OperationArgs; + type Output = i32; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + serde_json::from_value(json!({ + "name": "divide", + "description": "Compute the Quotient of x and y (i.e.: x / y). Useful for ratios.", + "parameters": { + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "The Dividend of the division. The number being divided" + }, + "y": { + "type": "number", + "description": "The Divisor of the division. 
The number by which the dividend is being divided" + } + } + } + })) + .expect("Tool Definition") + } + + async fn call(&self, args: Self::Args) -> Result { + let result = args.x / args.y; + Ok(result) + } +} diff --git a/rig-core/src/agent.rs b/rig-core/src/agent.rs deleted file mode 100644 index 3633d05..0000000 --- a/rig-core/src/agent.rs +++ /dev/null @@ -1,684 +0,0 @@ -//! This module contains the implementation of the [Agent] struct and its builder. -//! -//! The [Agent] struct represents an LLM agent, which combines an LLM model with a preamble (system prompt), -//! a set of context documents, and a set of tools. Note: both context documents and tools can be either -//! static (i.e.: they are always provided) or dynamic (i.e.: they are RAGged at prompt-time). -//! -//! The [Agent] struct is highly configurable, allowing the user to define anything from -//! a simple bot with a specific system prompt to a complex RAG system with a set of dynamic -//! context documents and tools. -//! -//! The [Agent] struct implements the [Completion] and [Prompt] traits, allowing it to be used for generating -//! completions responses and prompts. The [Agent] struct also implements the [Chat] trait, which allows it to -//! be used for generating chat completions. -//! -//! The [AgentBuilder] implements the builder pattern for creating instances of [Agent]. -//! It allows configuring the model, preamble, context documents, tools, temperature, and additional parameters -//! before building the agent. -//! -//! # Example -//! ```rust -//! use rig::{ -//! completion::{Chat, Completion, Prompt}, -//! providers::openai, -//! }; -//! -//! let openai = openai::Client::from_env(); -//! -//! // Configure the agent -//! let agent = openai.agent("gpt-4o") -//! .preamble("System prompt") -//! .context("Context document 1") -//! .context("Context document 2") -//! .tool(tool1) -//! .tool(tool2) -//! .temperature(0.8) -//! .additional_params(json!({"foo": "bar"})) -//! .build(); -//! -//! // Use the agent for completions and prompts -//! // Generate a chat completion response from a prompt and chat history -//! let chat_response = agent.chat("Prompt", chat_history) -//! .await -//! .expect("Failed to chat with Agent"); -//! -//! // Generate a prompt completion response from a simple prompt -//! let chat_response = agent.prompt("Prompt") -//! .await -//! .expect("Failed to prompt the Agent"); -//! -//! // Generate a completion request builder from a prompt and chat history. The builder -//! // will contain the agent's configuration (i.e.: preamble, context documents, tools, -//! // model parameters, etc.), but these can be overwritten. -//! let completion_req_builder = agent.completion("Prompt", chat_history) -//! .await -//! .expect("Failed to create completion request builder"); -//! -//! let response = completion_req_builder -//! .temperature(0.9) // Overwrite the agent's temperature -//! .send() -//! .await -//! .expect("Failed to send completion request"); -//! ``` -//! -//! RAG Agent example -//! ```rust -//! use rig::{ -//! completion::Prompt, -//! embeddings::EmbeddingsBuilder, -//! providers::openai, -//! vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, -//! }; -//! -//! // Initialize OpenAI client -//! let openai = openai::Client::from_env(); -//! -//! // Initialize OpenAI embedding model -//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002); -//! -//! // Create vector store, compute embeddings and load them in the store -//! 
let mut vector_store = InMemoryVectorStore::default(); -//! -//! let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) -//! .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") -//! .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") -//! .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") -//! .build() -//! .await -//! .expect("Failed to build embeddings"); -//! -//! vector_store.add_documents(embeddings) -//! .await -//! .expect("Failed to add documents"); -//! -//! // Create vector store index -//! let index = vector_store.index(embedding_model); -//! -//! let agent = openai.agent(openai::GPT_4O) -//! .preamble(" -//! You are a dictionary assistant here to assist the user in understanding the meaning of words. -//! You will find additional non-standard word definitions that could be useful below. -//! ") -//! .dynamic_context(1, index) -//! .build(); -//! -//! // Prompt the agent and print the response -//! let response = agent.prompt("What does \"glarb-glarb\" mean?").await -//! .expect("Failed to prompt the agent"); -//! ``` -use std::{collections::HashMap, future::IntoFuture}; - -use futures::{future::BoxFuture, stream, FutureExt, StreamExt, TryStreamExt}; - -use crate::{ - completion::{ - Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document, - Message, Prompt, PromptError, - }, - message::{AssistantContent, UserContent}, - streaming::{ - StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt, - StreamingResult, - }, - tool::{Tool, ToolSet, ToolSetError}, - vector_store::{VectorStoreError, VectorStoreIndexDyn}, - OneOrMany, -}; - -#[cfg(feature = "mcp")] -use crate::tool::McpTool; - -/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble -/// (i.e.: system prompt) and a static set of context documents and tools. -/// All context documents and tools are always provided to the agent when prompted. 
-/// -/// # Example -/// ``` -/// use rig::{completion::Prompt, providers::openai}; -/// -/// let openai = openai::Client::from_env(); -/// -/// let comedian_agent = openai -/// .agent("gpt-4o") -/// .preamble("You are a comedian here to entertain the user using humour and jokes.") -/// .temperature(0.9) -/// .build(); -/// -/// let response = comedian_agent.prompt("Entertain me!") -/// .await -/// .expect("Failed to prompt the agent"); -/// ``` -pub struct Agent { - /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r) - model: M, - /// System prompt - preamble: String, - /// Context documents always available to the agent - static_context: Vec, - /// Tools that are always available to the agent (identified by their name) - static_tools: Vec, - /// Temperature of the model - temperature: Option, - /// Maximum number of tokens for the completion - max_tokens: Option, - /// Additional parameters to be passed to the model - additional_params: Option, - /// List of vector store, with the sample number - dynamic_context: Vec<(usize, Box)>, - /// Dynamic tools - dynamic_tools: Vec<(usize, Box)>, - /// Actual tool implementations - pub tools: ToolSet, -} - -impl Completion for Agent { - async fn completion( - &self, - prompt: impl Into + Send, - chat_history: Vec, - ) -> Result, CompletionError> { - let prompt = prompt.into(); - let rag_text = prompt.rag_text().clone(); - - let completion_request = self - .model - .completion_request(prompt) - .preamble(self.preamble.clone()) - .messages(chat_history) - .temperature_opt(self.temperature) - .max_tokens_opt(self.max_tokens) - .additional_params_opt(self.additional_params.clone()) - .documents(self.static_context.clone()); - - let agent = match &rag_text { - Some(text) => { - let dynamic_context = stream::iter(self.dynamic_context.iter()) - .then(|(num_sample, index)| async { - Ok::<_, VectorStoreError>( - index - .top_n(text, *num_sample) - .await? - .into_iter() - .map(|(_, id, doc)| { - // Pretty print the document if possible for better readability - let text = serde_json::to_string_pretty(&doc) - .unwrap_or_else(|_| doc.to_string()); - - Document { - id, - text, - additional_props: HashMap::new(), - } - }) - .collect::>(), - ) - }) - .try_fold(vec![], |mut acc, docs| async { - acc.extend(docs); - Ok(acc) - }) - .await - .map_err(|e| CompletionError::RequestError(Box::new(e)))?; - - let dynamic_tools = stream::iter(self.dynamic_tools.iter()) - .then(|(num_sample, index)| async { - Ok::<_, VectorStoreError>( - index - .top_n_ids(text, *num_sample) - .await? 
- .into_iter() - .map(|(_, id)| id) - .collect::>(), - ) - }) - .try_fold(vec![], |mut acc, docs| async { - for doc in docs { - if let Some(tool) = self.tools.get(&doc) { - acc.push(tool.definition(text.into()).await) - } else { - tracing::warn!("Tool implementation not found in toolset: {}", doc); - } - } - Ok(acc) - }) - .await - .map_err(|e| CompletionError::RequestError(Box::new(e)))?; - - let static_tools = stream::iter(self.static_tools.iter()) - .filter_map(|toolname| async move { - if let Some(tool) = self.tools.get(toolname) { - Some(tool.definition(text.into()).await) - } else { - tracing::warn!( - "Tool implementation not found in toolset: {}", - toolname - ); - None - } - }) - .collect::>() - .await; - - completion_request - .documents(dynamic_context) - .tools([static_tools.clone(), dynamic_tools].concat()) - } - None => { - let static_tools = stream::iter(self.static_tools.iter()) - .filter_map(|toolname| async move { - if let Some(tool) = self.tools.get(toolname) { - // TODO: tool definitions should likely take an `Option` - Some(tool.definition("".into()).await) - } else { - tracing::warn!( - "Tool implementation not found in toolset: {}", - toolname - ); - None - } - }) - .collect::>() - .await; - - completion_request.tools(static_tools) - } - }; - - Ok(agent) - } -} - -// Here, we need to ensure that usage of `.prompt` on agent uses these redefinitions on the opaque -// `Prompt` trait so that when `.prompt` is used at the call-site, it'll use the more specific -// `PromptRequest` implementation for `Agent`, making the builder's usage fluent. -// -// References: -// - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait) - -#[allow(refining_impl_trait)] -impl Prompt for Agent { - fn prompt(&self, prompt: impl Into + Send) -> PromptRequest { - PromptRequest::new(self, prompt) - } -} - -#[allow(refining_impl_trait)] -impl Prompt for &Agent { - fn prompt(&self, prompt: impl Into + Send) -> PromptRequest { - PromptRequest::new(*self, prompt) - } -} - -#[allow(refining_impl_trait)] -impl Chat for Agent { - async fn chat( - &self, - prompt: impl Into + Send, - chat_history: Vec, - ) -> Result { - let mut cloned_history = chat_history.clone(); - PromptRequest::new(self, prompt) - .with_history(&mut cloned_history) - .await - } -} - -/// A builder for creating an agent -/// -/// # Example -/// ``` -/// use rig::{providers::openai, agent::AgentBuilder}; -/// -/// let openai = openai::Client::from_env(); -/// -/// let gpt4o = openai.completion_model("gpt-4o"); -/// -/// // Configure the agent -/// let agent = AgentBuilder::new(model) -/// .preamble("System prompt") -/// .context("Context document 1") -/// .context("Context document 2") -/// .tool(tool1) -/// .tool(tool2) -/// .temperature(0.8) -/// .additional_params(json!({"foo": "bar"})) -/// .build(); -/// ``` -pub struct AgentBuilder { - /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r) - model: M, - /// System prompt - preamble: Option, - /// Context documents always available to the agent - static_context: Vec, - /// Tools that are always available to the agent (by name) - static_tools: Vec, - /// Additional parameters to be passed to the model - additional_params: Option, - /// Maximum number of tokens for the completion - max_tokens: Option, - /// List of vector store, with the sample number - dynamic_context: Vec<(usize, Box)>, - /// Dynamic tools - dynamic_tools: Vec<(usize, Box)>, - /// Temperature of the model - temperature: Option, - /// Actual tool implementations - tools: 
ToolSet, -} - -impl AgentBuilder { - pub fn new(model: M) -> Self { - Self { - model, - preamble: None, - static_context: vec![], - static_tools: vec![], - temperature: None, - max_tokens: None, - additional_params: None, - dynamic_context: vec![], - dynamic_tools: vec![], - tools: ToolSet::default(), - } - } - - /// Set the system prompt - pub fn preamble(mut self, preamble: &str) -> Self { - self.preamble = Some(preamble.into()); - self - } - - /// Append to the preamble of the agent - pub fn append_preamble(mut self, doc: &str) -> Self { - self.preamble = Some(format!( - "{}\n{}", - self.preamble.unwrap_or_else(|| "".into()), - doc - )); - self - } - - /// Add a static context document to the agent - pub fn context(mut self, doc: &str) -> Self { - self.static_context.push(Document { - id: format!("static_doc_{}", self.static_context.len()), - text: doc.into(), - additional_props: HashMap::new(), - }); - self - } - - /// Add a static tool to the agent - pub fn tool(mut self, tool: impl Tool + 'static) -> Self { - let toolname = tool.name(); - self.tools.add_tool(tool); - self.static_tools.push(toolname); - self - } - - // Add an MCP tool to the agent - #[cfg(feature = "mcp")] - pub fn mcp_tool( - mut self, - tool: mcp_core::types::Tool, - client: mcp_core::client::Client, - ) -> Self { - let toolname = tool.name.clone(); - self.tools.add_tool(McpTool::from_mcp_server(tool, client)); - self.static_tools.push(toolname); - self - } - - /// Add some dynamic context to the agent. On each prompt, `sample` documents from the - /// dynamic context will be inserted in the request. - pub fn dynamic_context( - mut self, - sample: usize, - dynamic_context: impl VectorStoreIndexDyn + 'static, - ) -> Self { - self.dynamic_context - .push((sample, Box::new(dynamic_context))); - self - } - - /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the - /// dynamic toolset will be inserted in the request. - pub fn dynamic_tools( - mut self, - sample: usize, - dynamic_tools: impl VectorStoreIndexDyn + 'static, - toolset: ToolSet, - ) -> Self { - self.dynamic_tools.push((sample, Box::new(dynamic_tools))); - self.tools.add_tools(toolset); - self - } - - /// Set the temperature of the model - pub fn temperature(mut self, temperature: f64) -> Self { - self.temperature = Some(temperature); - self - } - - /// Set the maximum number of tokens for the completion - pub fn max_tokens(mut self, max_tokens: u64) -> Self { - self.max_tokens = Some(max_tokens); - self - } - - /// Set additional parameters to be passed to the model - pub fn additional_params(mut self, params: serde_json::Value) -> Self { - self.additional_params = Some(params); - self - } - - /// Build the agent - pub fn build(self) -> Agent { - Agent { - model: self.model, - preamble: self.preamble.unwrap_or_default(), - static_context: self.static_context, - static_tools: self.static_tools, - temperature: self.temperature, - max_tokens: self.max_tokens, - additional_params: self.additional_params, - dynamic_context: self.dynamic_context, - dynamic_tools: self.dynamic_tools, - tools: self.tools, - } - } -} - -/// A builder for creating prompt requests with customizable options. -/// Uses generics to track which options have been set during the build process. 
-pub struct PromptRequest<'c, 'a, M: CompletionModel> { - /// The prompt message to send to the model - prompt: Message, - /// Optional chat history to include with the prompt - /// Note: chat history needs to outlive the agent as it might be used with other agents - chat_history: Option<&'c mut Vec>, - /// Maximum depth for multi-turn conversations (0 means no multi-turn) - max_depth: usize, - /// The agent to use for execution - agent: &'a Agent, -} - -impl<'c: 'a, 'a, M: CompletionModel> PromptRequest<'c, 'a, M> { - /// Create a new PromptRequest with the given prompt and model - pub fn new(agent: &'c Agent, prompt: impl Into) -> Self { - Self { - prompt: prompt.into(), - chat_history: None, - max_depth: 0, - agent, - } - } -} - -impl<'c, 'a, M: CompletionModel> PromptRequest<'c, 'a, M> { - /// Set the maximum depth for multi-turn conversations - pub fn multi_turn(self, depth: usize) -> PromptRequest<'c, 'a, M> { - PromptRequest { - prompt: self.prompt, - chat_history: self.chat_history, - max_depth: depth, - agent: self.agent, - } - } - - /// Add chat history to the prompt request - pub fn with_history(self, history: &'c mut Vec) -> PromptRequest<'c, 'a, M> { - PromptRequest { - prompt: self.prompt, - chat_history: Some(history), - max_depth: self.max_depth, - agent: self.agent, - } - } -} - -/// Due to: RFC 2515, we have to use a `BoxFuture` for the `IntoFuture` implementation. In the -/// future, we should be able to use `impl Future<...>` directly via the associated type. -/// -/// Ref: https://github.com/rust-lang/rust/issues/63063 -impl<'c: 'a, 'a, M: CompletionModel + 'c> IntoFuture for PromptRequest<'c, 'a, M> { - type Output = Result; - type IntoFuture = BoxFuture<'a, Self::Output>; - - fn into_future(self) -> Self::IntoFuture { - self.send().boxed() - } -} - -// Implementation for Agent -impl PromptRequest<'_, '_, M> { - async fn send(self) -> Result { - let agent = self.agent; - let mut prompt = self.prompt; - let chat_history = if let Some(history) = self.chat_history { - history - } else { - &mut Vec::new() - }; - - let mut current_max_depth = 0; - while current_max_depth <= self.max_depth { - current_max_depth += 1; - - if self.max_depth > 1 { - tracing::info!( - "Current conversation depth: {}/{}", - current_max_depth, - self.max_depth - ); - } - - let resp = agent - .completion(prompt.clone(), chat_history.to_vec()) - .await? - .send() - .await?; - - chat_history.push(prompt); - - let (tool_calls, texts): (Vec<_>, Vec<_>) = resp - .choice - .iter() - .partition(|choice| matches!(choice, AssistantContent::ToolCall(_))); - - chat_history.push(Message::Assistant { - content: resp.choice.clone(), - }); - - if tool_calls.is_empty() { - let merged_texts = texts - .into_iter() - .filter_map(|content| { - if let AssistantContent::Text(text) = content { - Some(text.text.clone()) - } else { - None - } - }) - .collect::>() - .join("\n"); - - if self.max_depth > 1 { - tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth); - } - - // If there are no tool calls, depth is not relevant, we can just return the merged text. 
- return Ok(merged_texts); - } - - let tool_content = stream::iter(tool_calls) - .then(async |choice| { - if let AssistantContent::ToolCall(tool_call) = choice { - let output = agent - .tools - .call( - &tool_call.function.name, - tool_call.function.arguments.to_string(), - ) - .await?; - Ok(UserContent::tool_result( - tool_call.id.clone(), - OneOrMany::one(output.into()), - )) - } else { - unreachable!( - "This should never happen as we already filtered for `ToolCall`" - ) - } - }) - .collect::>>() - .await - .into_iter() - .collect::, _>>() - .map_err(|e| CompletionError::RequestError(Box::new(e)))?; - - prompt = Message::User { - content: OneOrMany::many(tool_content).expect("There is atleast one tool call"), - }; - } - - // If we reach here, we never resolved the final tool call. We need to do ... something. - Err(PromptError::MaxDepthError { - max_depth: self.max_depth, - chat_history: chat_history.clone(), - prompt, - }) - } -} - -impl StreamingCompletion for Agent { - async fn stream_completion( - &self, - prompt: impl Into + Send, - chat_history: Vec, - ) -> Result, CompletionError> { - // Reuse the existing completion implementation to build the request - // This ensures streaming and non-streaming use the same request building logic - self.completion(prompt, chat_history).await - } -} - -impl StreamingPrompt for Agent { - async fn stream_prompt(&self, prompt: &str) -> Result { - self.stream_chat(prompt, vec![]).await - } -} - -impl StreamingChat for Agent { - async fn stream_chat( - &self, - prompt: &str, - chat_history: Vec, - ) -> Result { - self.stream_completion(prompt, chat_history) - .await? - .stream() - .await - } -} diff --git a/rig-core/src/agent/agent.rs b/rig-core/src/agent/agent.rs new file mode 100644 index 0000000..abdf2e4 --- /dev/null +++ b/rig-core/src/agent/agent.rs @@ -0,0 +1,251 @@ +use std::collections::HashMap; + +use futures::{stream, StreamExt, TryStreamExt}; + +use crate::{ + completion::{ + Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document, + Message, Prompt, PromptError, + }, + streaming::{ + StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt, + StreamingResult, + }, + tool::ToolSet, + vector_store::VectorStoreError, +}; + +use super::prompt_request; +use super::prompt_request::PromptRequest; + +/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble +/// (i.e.: system prompt) and a static set of context documents and tools. +/// All context documents and tools are always provided to the agent when prompted. 
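+/// Dynamic context documents and dynamic tools can also be attached; those are
+/// retrieved from the configured vector stores at prompt time (see the module docs).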
+/// +/// # Example +/// ``` +/// use rig::{completion::Prompt, providers::openai}; +/// +/// let openai = openai::Client::from_env(); +/// +/// let comedian_agent = openai +/// .agent("gpt-4o") +/// .preamble("You are a comedian here to entertain the user using humour and jokes.") +/// .temperature(0.9) +/// .build(); +/// +/// let response = comedian_agent.prompt("Entertain me!") +/// .await +/// .expect("Failed to prompt the agent"); +/// ``` +pub struct Agent { + /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r) + pub model: M, + /// System prompt + pub preamble: String, + /// Context documents always available to the agent + pub static_context: Vec, + /// Tools that are always available to the agent (identified by their name) + pub static_tools: Vec, + /// Temperature of the model + pub temperature: Option, + /// Maximum number of tokens for the completion + pub max_tokens: Option, + /// Additional parameters to be passed to the model + pub additional_params: Option, + /// List of vector store, with the sample number + pub dynamic_context: Vec<(usize, Box)>, + /// Dynamic tools + pub dynamic_tools: Vec<(usize, Box)>, + /// Actual tool implementations + pub tools: ToolSet, +} + +impl Completion for Agent { + async fn completion( + &self, + prompt: impl Into + Send, + chat_history: Vec, + ) -> Result, CompletionError> { + let prompt = prompt.into(); + let rag_text = prompt.rag_text().clone(); + + let completion_request = self + .model + .completion_request(prompt) + .preamble(self.preamble.clone()) + .messages(chat_history) + .temperature_opt(self.temperature) + .max_tokens_opt(self.max_tokens) + .additional_params_opt(self.additional_params.clone()) + .documents(self.static_context.clone()); + + let agent = match &rag_text { + Some(text) => { + let dynamic_context = stream::iter(self.dynamic_context.iter()) + .then(|(num_sample, index)| async { + Ok::<_, VectorStoreError>( + index + .top_n(text, *num_sample) + .await? + .into_iter() + .map(|(_, id, doc)| { + // Pretty print the document if possible for better readability + let text = serde_json::to_string_pretty(&doc) + .unwrap_or_else(|_| doc.to_string()); + + Document { + id, + text, + additional_props: HashMap::new(), + } + }) + .collect::>(), + ) + }) + .try_fold(vec![], |mut acc, docs| async { + acc.extend(docs); + Ok(acc) + }) + .await + .map_err(|e| CompletionError::RequestError(Box::new(e)))?; + + let dynamic_tools = stream::iter(self.dynamic_tools.iter()) + .then(|(num_sample, index)| async { + Ok::<_, VectorStoreError>( + index + .top_n_ids(text, *num_sample) + .await? 
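+                                // Keep just the ids from the vector-store hits; the matching tool
+                                // definitions are resolved from the agent's toolset below.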
+ .into_iter() + .map(|(_, id)| id) + .collect::>(), + ) + }) + .try_fold(vec![], |mut acc, docs| async { + for doc in docs { + if let Some(tool) = self.tools.get(&doc) { + acc.push(tool.definition(text.into()).await) + } else { + tracing::warn!("Tool implementation not found in toolset: {}", doc); + } + } + Ok(acc) + }) + .await + .map_err(|e| CompletionError::RequestError(Box::new(e)))?; + + let static_tools = stream::iter(self.static_tools.iter()) + .filter_map(|toolname| async move { + if let Some(tool) = self.tools.get(toolname) { + Some(tool.definition(text.into()).await) + } else { + tracing::warn!( + "Tool implementation not found in toolset: {}", + toolname + ); + None + } + }) + .collect::>() + .await; + + completion_request + .documents(dynamic_context) + .tools([static_tools.clone(), dynamic_tools].concat()) + } + None => { + let static_tools = stream::iter(self.static_tools.iter()) + .filter_map(|toolname| async move { + if let Some(tool) = self.tools.get(toolname) { + // TODO: tool definitions should likely take an `Option` + Some(tool.definition("".into()).await) + } else { + tracing::warn!( + "Tool implementation not found in toolset: {}", + toolname + ); + None + } + }) + .collect::>() + .await; + + completion_request.tools(static_tools) + } + }; + + Ok(agent) + } +} + +// Here, we need to ensure that usage of `.prompt` on agent uses these redefinitions on the opaque +// `Prompt` trait so that when `.prompt` is used at the call-site, it'll use the more specific +// `PromptRequest` implementation for `Agent`, making the builder's usage fluent. +// +// References: +// - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait) + +#[allow(refining_impl_trait)] +impl Prompt for Agent { + fn prompt( + &self, + prompt: impl Into + Send, + ) -> PromptRequest { + PromptRequest::new(self, prompt) + } +} + +#[allow(refining_impl_trait)] +impl Prompt for &Agent { + fn prompt( + &self, + prompt: impl Into + Send, + ) -> PromptRequest { + PromptRequest::new(*self, prompt) + } +} + +#[allow(refining_impl_trait)] +impl Chat for Agent { + async fn chat( + &self, + prompt: impl Into + Send, + chat_history: Vec, + ) -> Result { + let mut cloned_history = chat_history.clone(); + PromptRequest::new(self, prompt) + .with_history(&mut cloned_history) + .await + } +} + +impl StreamingCompletion for Agent { + async fn stream_completion( + &self, + prompt: impl Into + Send, + chat_history: Vec, + ) -> Result, CompletionError> { + // Reuse the existing completion implementation to build the request + // This ensures streaming and non-streaming use the same request building logic + self.completion(prompt, chat_history).await + } +} + +impl StreamingPrompt for Agent { + async fn stream_prompt(&self, prompt: &str) -> Result { + self.stream_chat(prompt, vec![]).await + } +} + +impl StreamingChat for Agent { + async fn stream_chat( + &self, + prompt: &str, + chat_history: Vec, + ) -> Result { + self.stream_completion(prompt, chat_history) + .await? 
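+            // The request is built through the same path as a non-streaming
+            // completion; `.stream()` then opens the provider's streaming response.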
+ .stream() + .await + } +} diff --git a/rig-core/src/agent/builder.rs b/rig-core/src/agent/builder.rs new file mode 100644 index 0000000..8f16092 --- /dev/null +++ b/rig-core/src/agent/builder.rs @@ -0,0 +1,179 @@ +use std::collections::HashMap; + +use crate::{ + completion::{CompletionModel, Document}, + tool::{Tool, ToolSet}, + vector_store::VectorStoreIndexDyn, +}; + +#[cfg(feature = "mcp")] +use crate::tool::McpTool; + +use super::Agent; + +/// A builder for creating an agent +/// +/// # Example +/// ``` +/// use rig::{providers::openai, agent::AgentBuilder}; +/// +/// let openai = openai::Client::from_env(); +/// +/// let gpt4o = openai.completion_model("gpt-4o"); +/// +/// // Configure the agent +/// let agent = AgentBuilder::new(model) +/// .preamble("System prompt") +/// .context("Context document 1") +/// .context("Context document 2") +/// .tool(tool1) +/// .tool(tool2) +/// .temperature(0.8) +/// .additional_params(json!({"foo": "bar"})) +/// .build(); +/// ``` +pub struct AgentBuilder { + /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r) + model: M, + /// System prompt + preamble: Option, + /// Context documents always available to the agent + static_context: Vec, + /// Tools that are always available to the agent (by name) + static_tools: Vec, + /// Additional parameters to be passed to the model + additional_params: Option, + /// Maximum number of tokens for the completion + max_tokens: Option, + /// List of vector store, with the sample number + dynamic_context: Vec<(usize, Box)>, + /// Dynamic tools + dynamic_tools: Vec<(usize, Box)>, + /// Temperature of the model + temperature: Option, + /// Actual tool implementations + tools: ToolSet, +} + +impl AgentBuilder { + pub fn new(model: M) -> Self { + Self { + model, + preamble: None, + static_context: vec![], + static_tools: vec![], + temperature: None, + max_tokens: None, + additional_params: None, + dynamic_context: vec![], + dynamic_tools: vec![], + tools: ToolSet::default(), + } + } + + /// Set the system prompt + pub fn preamble(mut self, preamble: &str) -> Self { + self.preamble = Some(preamble.into()); + self + } + + /// Append to the preamble of the agent + pub fn append_preamble(mut self, doc: &str) -> Self { + self.preamble = Some(format!( + "{}\n{}", + self.preamble.unwrap_or_else(|| "".into()), + doc + )); + self + } + + /// Add a static context document to the agent + pub fn context(mut self, doc: &str) -> Self { + self.static_context.push(Document { + id: format!("static_doc_{}", self.static_context.len()), + text: doc.into(), + additional_props: HashMap::new(), + }); + self + } + + /// Add a static tool to the agent + pub fn tool(mut self, tool: impl Tool + 'static) -> Self { + let toolname = tool.name(); + self.tools.add_tool(tool); + self.static_tools.push(toolname); + self + } + + // Add an MCP tool to the agent + #[cfg(feature = "mcp")] + pub fn mcp_tool( + mut self, + tool: mcp_core::types::Tool, + client: mcp_core::client::Client, + ) -> Self { + let toolname = tool.name.clone(); + self.tools.add_tool(McpTool::from_mcp_server(tool, client)); + self.static_tools.push(toolname); + self + } + + /// Add some dynamic context to the agent. On each prompt, `sample` documents from the + /// dynamic context will be inserted in the request. + pub fn dynamic_context( + mut self, + sample: usize, + dynamic_context: impl VectorStoreIndexDyn + 'static, + ) -> Self { + self.dynamic_context + .push((sample, Box::new(dynamic_context))); + self + } + + /// Add some dynamic tools to the agent. 
On each prompt, `sample` tools from the + /// dynamic toolset will be inserted in the request. + pub fn dynamic_tools( + mut self, + sample: usize, + dynamic_tools: impl VectorStoreIndexDyn + 'static, + toolset: ToolSet, + ) -> Self { + self.dynamic_tools.push((sample, Box::new(dynamic_tools))); + self.tools.add_tools(toolset); + self + } + + /// Set the temperature of the model + pub fn temperature(mut self, temperature: f64) -> Self { + self.temperature = Some(temperature); + self + } + + /// Set the maximum number of tokens for the completion + pub fn max_tokens(mut self, max_tokens: u64) -> Self { + self.max_tokens = Some(max_tokens); + self + } + + /// Set additional parameters to be passed to the model + pub fn additional_params(mut self, params: serde_json::Value) -> Self { + self.additional_params = Some(params); + self + } + + /// Build the agent + pub fn build(self) -> Agent { + Agent { + model: self.model, + preamble: self.preamble.unwrap_or_default(), + static_context: self.static_context, + static_tools: self.static_tools, + temperature: self.temperature, + max_tokens: self.max_tokens, + additional_params: self.additional_params, + dynamic_context: self.dynamic_context, + dynamic_tools: self.dynamic_tools, + tools: self.tools, + } + } +} diff --git a/rig-core/src/agent/mod.rs b/rig-core/src/agent/mod.rs new file mode 100644 index 0000000..6dbb6f9 --- /dev/null +++ b/rig-core/src/agent/mod.rs @@ -0,0 +1,116 @@ +//! This module contains the implementation of the [Agent] struct and its builder. +//! +//! The [Agent] struct represents an LLM agent, which combines an LLM model with a preamble (system prompt), +//! a set of context documents, and a set of tools. Note: both context documents and tools can be either +//! static (i.e.: they are always provided) or dynamic (i.e.: they are RAGged at prompt-time). +//! +//! The [Agent] struct is highly configurable, allowing the user to define anything from +//! a simple bot with a specific system prompt to a complex RAG system with a set of dynamic +//! context documents and tools. +//! +//! The [Agent] struct implements the [Completion] and [Prompt] traits, allowing it to be used for generating +//! completions responses and prompts. The [Agent] struct also implements the [Chat] trait, which allows it to +//! be used for generating chat completions. +//! +//! The [AgentBuilder] implements the builder pattern for creating instances of [Agent]. +//! It allows configuring the model, preamble, context documents, tools, temperature, and additional parameters +//! before building the agent. +//! +//! # Example +//! ```rust +//! use rig::{ +//! completion::{Chat, Completion, Prompt}, +//! providers::openai, +//! }; +//! +//! let openai = openai::Client::from_env(); +//! +//! // Configure the agent +//! let agent = openai.agent("gpt-4o") +//! .preamble("System prompt") +//! .context("Context document 1") +//! .context("Context document 2") +//! .tool(tool1) +//! .tool(tool2) +//! .temperature(0.8) +//! .additional_params(json!({"foo": "bar"})) +//! .build(); +//! +//! // Use the agent for completions and prompts +//! // Generate a chat completion response from a prompt and chat history +//! let chat_response = agent.chat("Prompt", chat_history) +//! .await +//! .expect("Failed to chat with Agent"); +//! +//! // Generate a prompt completion response from a simple prompt +//! let chat_response = agent.prompt("Prompt") +//! .await +//! .expect("Failed to prompt the Agent"); +//! +//! 
// Generate a completion request builder from a prompt and chat history. The builder +//! // will contain the agent's configuration (i.e.: preamble, context documents, tools, +//! // model parameters, etc.), but these can be overwritten. +//! let completion_req_builder = agent.completion("Prompt", chat_history) +//! .await +//! .expect("Failed to create completion request builder"); +//! +//! let response = completion_req_builder +//! .temperature(0.9) // Overwrite the agent's temperature +//! .send() +//! .await +//! .expect("Failed to send completion request"); +//! ``` +//! +//! RAG Agent example +//! ```rust +//! use rig::{ +//! completion::Prompt, +//! embeddings::EmbeddingsBuilder, +//! providers::openai, +//! vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, +//! }; +//! +//! // Initialize OpenAI client +//! let openai = openai::Client::from_env(); +//! +//! // Initialize OpenAI embedding model +//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002); +//! +//! // Create vector store, compute embeddings and load them in the store +//! let mut vector_store = InMemoryVectorStore::default(); +//! +//! let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) +//! .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") +//! .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") +//! .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") +//! .build() +//! .await +//! .expect("Failed to build embeddings"); +//! +//! vector_store.add_documents(embeddings) +//! .await +//! .expect("Failed to add documents"); +//! +//! // Create vector store index +//! let index = vector_store.index(embedding_model); +//! +//! let agent = openai.agent(openai::GPT_4O) +//! .preamble(" +//! You are a dictionary assistant here to assist the user in understanding the meaning of words. +//! You will find additional non-standard word definitions that could be useful below. +//! ") +//! .dynamic_context(1, index) +//! .build(); +//! +//! // Prompt the agent and print the response +//! let response = agent.prompt("What does \"glarb-glarb\" mean?").await +//! .expect("Failed to prompt the agent"); +//! ``` + +mod agent; +mod builder; +mod prompt_request; + +pub use agent::Agent; +pub use builder::AgentBuilder; +pub use prompt_request::PromptRequest; diff --git a/rig-core/src/agent/prompt_request.rs b/rig-core/src/agent/prompt_request.rs new file mode 100644 index 0000000..e16af68 --- /dev/null +++ b/rig-core/src/agent/prompt_request.rs @@ -0,0 +1,232 @@ +use std::{ + future::{Future, IntoFuture}, + marker::PhantomData, +}; + +use futures::{future::BoxFuture, stream, FutureExt, StreamExt}; + +use crate::{ + completion::{Completion, CompletionError, CompletionModel, Message, PromptError}, + message::{AssistantContent, UserContent}, + tool::ToolSetError, + OneOrMany, +}; + +use super::Agent; + +pub trait State {} +pub struct Simple; +pub struct MultiTurn; + +impl State for Simple {} +impl State for MultiTurn {} + +pub trait SendPromptRequest { + fn send(self) -> impl Future> + Send; +} + +/// A builder for creating prompt requests with customizable options. +/// Uses generics to track which options have been set during the build process. 
+pub struct PromptRequest<'c, 'a, M: CompletionModel, T: State> { + /// The prompt message to send to the model + prompt: Message, + /// Optional chat history to include with the prompt + /// Note: chat history needs to outlive the agent as it might be used with other agents + chat_history: Option<&'c mut Vec>, + /// Maximum depth for multi-turn conversations (0 means no multi-turn) + max_depth: usize, + /// The agent to use for execution + agent: &'a Agent, + + /// Typestate + _state: PhantomData, +} + +impl<'c: 'a, 'a, M: CompletionModel> PromptRequest<'c, 'a, M, Simple> { + /// Create a new PromptRequest with the given prompt and model + pub fn new(agent: &'c Agent, prompt: impl Into) -> Self { + Self { + prompt: prompt.into(), + chat_history: None, + max_depth: 0, + agent, + _state: PhantomData, + } + } +} + +impl<'c, 'a, M: CompletionModel> PromptRequest<'c, 'a, M, Simple> { + /// Set the maximum depth for multi-turn conversations + pub fn multi_turn(self, depth: usize) -> PromptRequest<'c, 'a, M, MultiTurn> { + PromptRequest { + prompt: self.prompt, + chat_history: self.chat_history, + max_depth: depth, + agent: self.agent, + _state: PhantomData, + } + } + + /// Add chat history to the prompt request + pub fn with_history(self, history: &'c mut Vec) -> PromptRequest<'c, 'a, M, Simple> { + PromptRequest { + prompt: self.prompt, + chat_history: Some(history), + max_depth: self.max_depth, + agent: self.agent, + _state: PhantomData, + } + } +} + +/// Due to: RFC 2515, we have to use a `BoxFuture` for the `IntoFuture` implementation. In the +/// future, we should be able to use `impl Future<...>` directly via the associated type. +/// +/// Ref: https://github.com/rust-lang/rust/issues/63063 +impl<'c: 'a, 'a, M: CompletionModel, T: State + 'a> IntoFuture for PromptRequest<'c, 'a, M, T> +where + PromptRequest<'c, 'a, M, T>: SendPromptRequest, +{ + type Output = Result; + type IntoFuture = BoxFuture<'a, Self::Output>; + + fn into_future(self) -> Self::IntoFuture { + self.send().boxed() + } +} + +impl SendPromptRequest for PromptRequest<'_, '_, M, MultiTurn> { + async fn send(self) -> Result { + let agent = self.agent; + let mut prompt = self.prompt; + let chat_history = if let Some(history) = self.chat_history { + history + } else { + &mut Vec::new() + }; + + let mut current_max_depth = 0; + // We need to do atleast 2 loops for 1 roundtrip (user expects normal message) + while current_max_depth <= self.max_depth + 1 { + current_max_depth += 1; + + if self.max_depth > 1 { + tracing::info!( + "Current conversation depth: {}/{}", + current_max_depth, + self.max_depth + ); + } + + let resp = agent + .completion(prompt.clone(), chat_history.to_vec()) + .await? + .send() + .await?; + + chat_history.push(prompt); + + let (tool_calls, texts): (Vec<_>, Vec<_>) = resp + .choice + .iter() + .partition(|choice| matches!(choice, AssistantContent::ToolCall(_))); + + chat_history.push(Message::Assistant { + content: resp.choice.clone(), + }); + + if tool_calls.is_empty() { + let merged_texts = texts + .into_iter() + .filter_map(|content| { + if let AssistantContent::Text(text) = content { + Some(text.text.clone()) + } else { + None + } + }) + .collect::>() + .join("\n"); + + if self.max_depth > 1 { + tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth); + } + + // If there are no tool calls, depth is not relevant, we can just return the merged text. 
+ return Ok(merged_texts); + } + + let tool_content = stream::iter(tool_calls) + .then(async |choice| { + if let AssistantContent::ToolCall(tool_call) = choice { + let output = agent + .tools + .call( + &tool_call.function.name, + tool_call.function.arguments.to_string(), + ) + .await?; + Ok(UserContent::tool_result( + tool_call.id.clone(), + OneOrMany::one(output.into()), + )) + } else { + unreachable!( + "This should never happen as we already filtered for `ToolCall`" + ) + } + }) + .collect::>>() + .await + .into_iter() + .collect::, _>>() + .map_err(|e| CompletionError::RequestError(Box::new(e)))?; + + prompt = Message::User { + content: OneOrMany::many(tool_content).expect("There is atleast one tool call"), + }; + } + + // If we reach here, we never resolved the final tool call. We need to do ... something. + Err(PromptError::MaxDepthError { + max_depth: self.max_depth, + chat_history: chat_history.clone(), + prompt, + }) + } +} + +impl SendPromptRequest for PromptRequest<'_, '_, M, Simple> { + async fn send(self) -> Result { + let chat_history = if let Some(history) = self.chat_history { + history.clone() + } else { + Vec::new() + }; + + let resp = self + .agent + .completion(self.prompt, chat_history) + .await? + .send() + .await?; + + tracing::debug!(?resp.choice); + + if resp.choice.len() > 1 { + tracing::warn!("Parallel tool calls are only available when using multi turn. Use `agent.prompt(...).multi_turn(depth).await`!"); + } + + match resp.choice.first() { + AssistantContent::Text(text) => Ok(text.text.clone()), + AssistantContent::ToolCall(tool_call) => Ok(self + .agent + .tools + .call( + &tool_call.function.name, + tool_call.function.arguments.to_string(), + ) + .await?), + } + } +} diff --git a/rig-core/src/completion/message.rs b/rig-core/src/completion/message.rs index e2c4fd8..c2a89e5 100644 --- a/rig-core/src/completion/message.rs +++ b/rig-core/src/completion/message.rs @@ -477,6 +477,12 @@ impl From for Text { } } +impl From<&String> for Text { + fn from(text: &String) -> Self { + text.to_owned().into() + } +} + impl From<&str> for Text { fn from(text: &str) -> Self { text.to_owned().into() @@ -507,6 +513,14 @@ impl From<&str> for Message { } } +impl From<&String> for Message { + fn from(text: &String) -> Self { + Message::User { + content: OneOrMany::one(UserContent::Text(text.into())), + } + } +} + impl From for Message { fn from(text: Text) -> Self { Message::User { diff --git a/rig-core/src/extractor.rs b/rig-core/src/extractor.rs index 9b70a14..34e67d6 100644 --- a/rig-core/src/extractor.rs +++ b/rig-core/src/extractor.rs @@ -37,6 +37,7 @@ use serde_json::json; use crate::{ agent::{Agent, AgentBuilder}, completion::{CompletionModel, Prompt, PromptError, ToolDefinition}, + message::Message, tool::Tool, }; @@ -62,7 +63,7 @@ impl Deserialize<'a> + Send + Sync, M: CompletionModel> where M: Sync, { - pub async fn extract(&self, text: &str) -> Result { + pub async fn extract(&self, text: impl Into + Send) -> Result { let summary = self.agent.prompt(text).await?; if summary.is_empty() { diff --git a/rig-core/src/pipeline/agent_ops.rs b/rig-core/src/pipeline/agent_ops.rs index 0bdb8cc..33298cf 100644 --- a/rig-core/src/pipeline/agent_ops.rs +++ b/rig-core/src/pipeline/agent_ops.rs @@ -1,9 +1,7 @@ use std::future::IntoFuture; use crate::{ - completion::{self, CompletionModel}, - extractor::{ExtractionError, Extractor}, - vector_store, + completion::{self, CompletionModel}, extractor::{ExtractionError, Extractor}, message::Message, vector_store }; use super::Op; @@ -129,13 
+127,13 @@ impl Op for Extract
 where
     M: CompletionModel,
     Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
-    Input: Into<String> + Send + Sync,
+    Input: Into<Message> + Send + Sync,
 {
     type Input = Input;
     type Output = Result<Output, ExtractionError>;

     async fn call(&self, input: Self::Input) -> Self::Output {
-        self.extractor.extract(&input.into()).await
+        self.extractor.extract(input).await
     }
 }
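Usage sketch for the widened `extract` signature (illustrative only; `extractor` stands in for any built `Extractor`, such as the `chain_of_thought_extractor` from the example above):

    use rig::message::Message;

    // Plain string slices keep working as before.
    let out = extractor.extract("John Doe is 42 years old").await?;

    // Anything that is `Into<Message>` can now also be passed directly,
    // e.g. a user message converted from a &str.
    let msg: Message = "Jane Doe is 35 years old".into();
    let out = extractor.extract(msg).await?;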