commit 7245a06922
Mochan, 2025-04-18 14:00:45 -07:00, committed by GitHub
36 changed files with 1449 additions and 951 deletions

View File

@ -120,13 +120,13 @@ impl completion::CompletionModel for CompletionModel {
.model_id(self.model.as_str());
let tool_config = request.tools_config()?;
let prompt_with_history = request.prompt_with_history()?;
let messages = request.messages()?;
converse_builder = converse_builder
.set_additional_model_request_fields(request.additional_params())
.set_inference_config(request.inference_config())
.set_tool_config(tool_config)
.set_system(request.system_prompt())
.set_messages(Some(prompt_with_history));
.set_messages(Some(messages));
let response = converse_builder
.send()

View File

@ -28,7 +28,7 @@ impl StreamingCompletionModel for CompletionModel {
.model_id(self.model.as_str());
let tool_config = request.tools_config()?;
let prompt_with_history = request.prompt_with_history()?;
let prompt_with_history = request.messages()?;
converse_builder = converse_builder
.set_additional_model_request_fields(request.additional_params())
.set_inference_config(request.inference_config())

View File

@ -6,6 +6,8 @@ use aws_sdk_bedrockruntime::types::{
ToolSpecification,
};
use rig::completion::{CompletionError, Message};
use rig::message::{ContentFormat, DocumentMediaType, UserContent};
use rig::OneOrMany;
pub struct AwsCompletionRequest(pub rig::completion::CompletionRequest);
@ -69,13 +71,30 @@ impl AwsCompletionRequest {
.map(|system_prompt| vec![SystemContentBlock::Text(system_prompt)])
}
pub fn prompt_with_history(&self) -> Result<Vec<aws_bedrock::Message>, CompletionError> {
let mut chat_history = self.0.chat_history.to_owned();
let prompt_with_context = self.0.prompt_with_context();
pub fn messages(&self) -> Result<Vec<aws_bedrock::Message>, CompletionError> {
let mut full_history: Vec<Message> = Vec::new();
full_history.append(&mut chat_history);
full_history.push(prompt_with_context);
if !self.0.documents.is_empty() {
let messages = self
.0
.documents
.iter()
.map(|doc| doc.to_string())
.collect::<Vec<_>>()
.join(" | ");
let content = OneOrMany::one(UserContent::document(
messages,
Some(ContentFormat::String),
Some(DocumentMediaType::TXT),
));
full_history.push(Message::User { content });
}
self.0.chat_history.iter().for_each(|message| {
full_history.push(message.clone());
});
full_history
.into_iter()
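For reference, the new document handling above can be read as a small standalone helper: when documents are present they are joined into a single user document message that precedes the existing chat history. The sketch below is a hypothetical simplification under that reading, not the provider code itself.

```rust
use rig::completion::Message;
use rig::message::{ContentFormat, DocumentMediaType, UserContent};
use rig::OneOrMany;

// Hypothetical helper mirroring the ordering used above: documents (if any)
// become one user "document" message, followed by the untouched chat history.
fn merge_documents_into_history(documents: &[String], chat_history: &[Message]) -> Vec<Message> {
    let mut full_history: Vec<Message> = Vec::new();

    if !documents.is_empty() {
        let content = OneOrMany::one(UserContent::document(
            documents.join(" | "),
            Some(ContentFormat::String),
            Some(DocumentMediaType::TXT),
        ));
        full_history.push(Message::User { content });
    }

    // The chat history (whose last entry is the prompt) follows the documents.
    full_history.extend(chat_history.iter().cloned());
    full_history
}
```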

View File

@ -11,7 +11,7 @@ async fn main() -> Result<(), anyhow::Error> {
// Create agent with a single context prompt
let comedian_agent = client
.agent("cognitivecomputations/dolphin3.0-mistral-24b:free")
.agent("google/gemini-2.5-pro-exp-03-25:free")
.preamble("You are a comedian here to entertain the user using humour and jokes.")
.build();

View File

@ -3,7 +3,7 @@ use std::env;
use anyhow::Result;
use rig::{
agent::Agent,
completion::Chat,
completion::Prompt,
message::Message,
providers::{cohere, openai},
};
@ -49,20 +49,20 @@ impl Debater {
let resp_a = self
.gpt_4
.chat(prompt_a.as_str(), history_a.clone())
.prompt(prompt_a.as_str())
.with_history(&mut history_a)
.await?;
println!("GPT-4:\n{}", resp_a);
history_a.push(Message::user(prompt_a));
history_a.push(Message::assistant(resp_a.clone()));
println!("================================================================");
let resp_b = self.coral.chat(resp_a.as_str(), history_b.clone()).await?;
let resp_b = self
.coral
.prompt(resp_a.as_str())
.with_history(&mut history_b)
.await?;
println!("Coral:\n{}", resp_b);
println!("================================================================");
history_b.push(Message::user(resp_a));
history_b.push(Message::assistant(resp_b.clone()));
last_resp_b = Some(resp_b)
}
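The same pattern can be factored into a small helper: `prompt(...)` builds a `PromptRequest`, and `with_history` borrows the history mutably so the request appends the prompt and the assistant reply for you. A minimal sketch, assuming a generic rig `Agent`:

```rust
use rig::{
    agent::Agent,
    completion::{CompletionModel, Prompt, PromptError},
    message::Message,
};

// One conversational turn using the new fluent API. `history` is updated in
// place with the user prompt and the assistant's response.
async fn one_turn<M: CompletionModel>(
    agent: &Agent<M>,
    input: &str,
    history: &mut Vec<Message>,
) -> Result<String, PromptError> {
    agent.prompt(input).with_history(history).await
}
```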

View File

@ -35,6 +35,7 @@ impl<M: CompletionModel> EnglishTranslator<M> {
}
impl<M: CompletionModel> Chat for EnglishTranslator<M> {
#[allow(refining_impl_trait)]
async fn chat(
&self,
prompt: impl Into<Message> + Send,

View File

@ -1,99 +1,23 @@
use rig::{
agent::Agent,
completion::{self, Completion, PromptError, ToolDefinition},
message::{AssistantContent, Message, ToolCall, ToolFunction, ToolResultContent, UserContent},
completion::{Prompt, ToolDefinition},
providers::anthropic,
tool::Tool,
OneOrMany,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
struct MultiTurnAgent<M: rig::completion::CompletionModel> {
agent: Agent<M>,
chat_history: Vec<completion::Message>,
}
impl<M: rig::completion::CompletionModel> MultiTurnAgent<M> {
async fn multi_turn_prompt(
&mut self,
prompt: impl Into<Message> + Send,
) -> Result<String, PromptError> {
let mut current_prompt: Message = prompt.into();
loop {
println!("Current Prompt: {:?}\n", current_prompt);
let resp = self
.agent
.completion(current_prompt.clone(), self.chat_history.clone())
.await?
.send()
.await?;
let mut final_text = None;
for content in resp.choice.into_iter() {
match content {
AssistantContent::Text(text) => {
println!("Intermediate Response: {:?}\n", text.text);
final_text = Some(text.text.clone());
self.chat_history.push(current_prompt.clone());
let response_message = Message::Assistant {
content: OneOrMany::one(AssistantContent::text(&text.text)),
};
self.chat_history.push(response_message);
}
AssistantContent::ToolCall(content) => {
self.chat_history.push(current_prompt.clone());
let tool_call_msg = AssistantContent::ToolCall(content.clone());
println!("Tool Call Msg: {:?}\n", tool_call_msg);
self.chat_history.push(Message::Assistant {
content: OneOrMany::one(tool_call_msg),
});
let ToolCall {
id,
function: ToolFunction { name, arguments },
} = content;
let tool_result =
self.agent.tools.call(&name, arguments.to_string()).await?;
current_prompt = Message::User {
content: OneOrMany::one(UserContent::tool_result(
id,
OneOrMany::one(ToolResultContent::text(tool_result)),
)),
};
final_text = None;
break;
}
}
}
if let Some(text) = final_text {
return Ok(text);
}
}
}
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// tracing_subscriber::registry()
// .with(
// tracing_subscriber::EnvFilter::try_from_default_env()
// .unwrap_or_else(|_| "stdout=info".into()),
// )
// .with(tracing_subscriber::fmt::layer())
// .init();
tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.with_target(false)
.init();
// Create Anthropic client
let openai_client = anthropic::Client::from_env();
// Create RAG agent with a single context prompt and a dynamic tool source
let calculator_rag = openai_client
let agent = openai_client
.agent(anthropic::CLAUDE_3_5_SONNET)
.preamble(
"You are an assistant here to help the user select which tool is most appropriate to perform arithmetic operations.
@ -109,21 +33,18 @@ async fn main() -> anyhow::Result<()> {
.tool(Divide)
.build();
let mut agent = MultiTurnAgent {
agent: calculator_rag,
chat_history: Vec::new(),
};
// Prompt the agent and print the response
let result = agent
.multi_turn_prompt("Calculate 5 - 2 = ?. Describe the result to me.")
.prompt("Calculate 5 - 2 = ?. Describe the result to me.")
.multi_turn(20)
.await?;
println!("\n\nOpenAI Calculator Agent: {}", result);
// Prompt the agent again and print the response
let result = agent
.multi_turn_prompt("Calculate (3 + 5) / 9 = ?. Describe the result to me.")
.prompt("Calculate (3 + 5) / 9 = ?. Describe the result to me.")
.multi_turn(20)
.await?;
println!("\n\nOpenAI Calculator Agent: {}", result);
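With `multi_turn`, an agent that is still issuing tool calls when the turn budget runs out resolves to the new `PromptError::MaxDepthError`. A hedged sketch of handling it (the prompt text and turn limit here are arbitrary):

```rust
use rig::{
    agent::Agent,
    completion::{CompletionModel, Prompt, PromptError},
};

// Returns the final answer, or a descriptive error if the agent never stopped
// calling tools within the configured number of turns.
async fn prompt_with_limit<M: CompletionModel>(agent: &Agent<M>) -> anyhow::Result<String> {
    match agent
        .prompt("Calculate 5 - 2 = ?. Describe the result to me.")
        .multi_turn(2)
        .await
    {
        Ok(answer) => Ok(answer),
        Err(PromptError::MaxDepthError { max_depth, .. }) => {
            anyhow::bail!("agent did not finish within {max_depth} turns")
        }
        Err(other) => Err(other.into()),
    }
}
```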

View File

@ -0,0 +1,261 @@
use rig::{
agent::Agent,
completion::{CompletionError, CompletionModel, Prompt, PromptError, ToolDefinition},
extractor::Extractor,
message::Message,
providers::anthropic,
tool::Tool,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
const CHAIN_OF_THOUGHT_PROMPT: &str = "
You are an assistant that extracts reasoning steps from a given prompt.
Do not return text, only return a tool call.
";
#[derive(Deserialize, Serialize, Debug, Clone, JsonSchema)]
struct ChainOfThoughtSteps {
steps: Vec<String>,
}
struct ReasoningAgent<M: CompletionModel> {
chain_of_thought_extractor: Extractor<M, ChainOfThoughtSteps>,
executor: Agent<M>,
}
impl<M: CompletionModel> Prompt for ReasoningAgent<M> {
#[allow(refining_impl_trait)]
async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
let prompt: Message = prompt.into();
let mut chat_history = vec![prompt.clone()];
let extracted = self
.chain_of_thought_extractor
.extract(prompt)
.await
.map_err(|e| {
tracing::error!("Extraction error: {:?}", e);
CompletionError::ProviderError("".into())
})?;
if extracted.steps.is_empty() {
return Ok("No reasoning steps provided.".into());
}
let mut reasoning_prompt = String::new();
for (i, step) in extracted.steps.iter().enumerate() {
reasoning_prompt.push_str(&format!("Step {}: {}\n", i + 1, step));
}
let response = self
.executor
.prompt(reasoning_prompt.as_str())
.with_history(&mut chat_history)
.multi_turn(20)
.await?;
tracing::info!(
"full chat history generated: {}",
serde_json::to_string_pretty(&chat_history).unwrap()
);
Ok(response)
}
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.with_target(false)
.init();
// Create Anthropic client
let anthropic_client = anthropic::Client::from_env();
let agent = ReasoningAgent {
chain_of_thought_extractor: anthropic_client
.extractor(anthropic::CLAUDE_3_5_SONNET)
.preamble(CHAIN_OF_THOUGHT_PROMPT)
.build(),
executor: anthropic_client
.agent(anthropic::CLAUDE_3_5_SONNET)
.preamble(
"You are an assistant here to help the user select which tool is most appropriate to perform arithmetic operations.
Follow these instructions closely.
1. Consider the user's request carefully and identify the core elements of the request.
2. Select which tool among those made available to you is appropriate given the context.
3. This is very important: never perform the operation yourself.
4. When you think you've finished calling tools for the operation, present the final result from the series of tool calls you made.
"
)
.tool(Add)
.tool(Subtract)
.tool(Multiply)
.tool(Divide)
.build(),
};
// Prompt the agent and print the response
let result = agent
.prompt("Calculate ((15 + 25) * (100 - 50)) / (200 / (10 + 10))")
.await?;
println!("\n\nReasoning Agent Chat History: {}", result);
Ok(())
}
#[derive(Deserialize)]
struct OperationArgs {
x: i32,
y: i32,
}
#[derive(Debug, thiserror::Error)]
#[error("Math error")]
struct MathError;
#[derive(Deserialize, Serialize)]
struct Add;
impl Tool for Add {
const NAME: &'static str = "add";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
serde_json::from_value(json!({
"name": "add",
"description": "Add x and y together",
"parameters": {
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The first number to add"
},
"y": {
"type": "number",
"description": "The second number to add"
}
}
}
}))
.expect("Tool Definition")
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let result = args.x + args.y;
Ok(result)
}
}
#[derive(Deserialize, Serialize)]
struct Subtract;
impl Tool for Subtract {
const NAME: &'static str = "subtract";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
serde_json::from_value(json!({
"name": "subtract",
"description": "Subtract y from x (i.e.: x - y)",
"parameters": {
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The number to subtract from"
},
"y": {
"type": "number",
"description": "The number to subtract"
}
}
}
}))
.expect("Tool Definition")
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let result = args.x - args.y;
Ok(result)
}
}
struct Multiply;
impl Tool for Multiply {
const NAME: &'static str = "multiply";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
serde_json::from_value(json!({
"name": "multiply",
"description": "Compute the product of x and y (i.e.: x * y)",
"parameters": {
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The first factor in the product"
},
"y": {
"type": "number",
"description": "The second factor in the product"
}
}
}
}))
.expect("Tool Definition")
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let result = args.x * args.y;
Ok(result)
}
}
struct Divide;
impl Tool for Divide {
const NAME: &'static str = "divide";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
serde_json::from_value(json!({
"name": "divide",
"description": "Compute the Quotient of x and y (i.e.: x / y). Useful for ratios.",
"parameters": {
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The Dividend of the division. The number being divided"
},
"y": {
"type": "number",
"description": "The Divisor of the division. The number by which the dividend is being divided"
}
}
}
}))
.expect("Tool Definition")
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let result = args.x / args.y;
Ok(result)
}
}

View File

@ -1,520 +0,0 @@
//! This module contains the implementation of the [Agent] struct and its builder.
//!
//! The [Agent] struct represents an LLM agent, which combines an LLM model with a preamble (system prompt),
//! a set of context documents, and a set of tools. Note: both context documents and tools can be either
//! static (i.e.: they are always provided) or dynamic (i.e.: they are RAGged at prompt-time).
//!
//! The [Agent] struct is highly configurable, allowing the user to define anything from
//! a simple bot with a specific system prompt to a complex RAG system with a set of dynamic
//! context documents and tools.
//!
//! The [Agent] struct implements the [Completion] and [Prompt] traits, allowing it to be used for generating
//! completions responses and prompts. The [Agent] struct also implements the [Chat] trait, which allows it to
//! be used for generating chat completions.
//!
//! The [AgentBuilder] implements the builder pattern for creating instances of [Agent].
//! It allows configuring the model, preamble, context documents, tools, temperature, and additional parameters
//! before building the agent.
//!
//! # Example
//! ```rust
//! use rig::{
//! completion::{Chat, Completion, Prompt},
//! providers::openai,
//! };
//!
//! let openai = openai::Client::from_env();
//!
//! // Configure the agent
//! let agent = openai.agent("gpt-4o")
//! .preamble("System prompt")
//! .context("Context document 1")
//! .context("Context document 2")
//! .tool(tool1)
//! .tool(tool2)
//! .temperature(0.8)
//! .additional_params(json!({"foo": "bar"}))
//! .build();
//!
//! // Use the agent for completions and prompts
//! // Generate a chat completion response from a prompt and chat history
//! let chat_response = agent.chat("Prompt", chat_history)
//! .await
//! .expect("Failed to chat with Agent");
//!
//! // Generate a prompt completion response from a simple prompt
//! let chat_response = agent.prompt("Prompt")
//! .await
//! .expect("Failed to prompt the Agent");
//!
//! // Generate a completion request builder from a prompt and chat history. The builder
//! // will contain the agent's configuration (i.e.: preamble, context documents, tools,
//! // model parameters, etc.), but these can be overwritten.
//! let completion_req_builder = agent.completion("Prompt", chat_history)
//! .await
//! .expect("Failed to create completion request builder");
//!
//! let response = completion_req_builder
//! .temperature(0.9) // Overwrite the agent's temperature
//! .send()
//! .await
//! .expect("Failed to send completion request");
//! ```
//!
//! RAG Agent example
//! ```rust
//! use rig::{
//! completion::Prompt,
//! embeddings::EmbeddingsBuilder,
//! providers::openai,
//! vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
//! };
//!
//! // Initialize OpenAI client
//! let openai = openai::Client::from_env();
//!
//! // Initialize OpenAI embedding model
//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002);
//!
//! // Create vector store, compute embeddings and load them in the store
//! let mut vector_store = InMemoryVectorStore::default();
//!
//! let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
//! .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
//! .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
//! .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
//! .build()
//! .await
//! .expect("Failed to build embeddings");
//!
//! vector_store.add_documents(embeddings)
//! .await
//! .expect("Failed to add documents");
//!
//! // Create vector store index
//! let index = vector_store.index(embedding_model);
//!
//! let agent = openai.agent(openai::GPT_4O)
//! .preamble("
//! You are a dictionary assistant here to assist the user in understanding the meaning of words.
//! You will find additional non-standard word definitions that could be useful below.
//! ")
//! .dynamic_context(1, index)
//! .build();
//!
//! // Prompt the agent and print the response
//! let response = agent.prompt("What does \"glarb-glarb\" mean?").await
//! .expect("Failed to prompt the agent");
//! ```
use std::collections::HashMap;
use futures::{stream, StreamExt, TryStreamExt};
use crate::{
completion::{
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
Message, Prompt, PromptError,
},
message::AssistantContent,
streaming::{
StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
StreamingResult,
},
tool::{Tool, ToolSet},
vector_store::{VectorStoreError, VectorStoreIndexDyn},
};
#[cfg(feature = "mcp")]
use crate::tool::McpTool;
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
/// All context documents and tools are always provided to the agent when prompted.
///
/// # Example
/// ```
/// use rig::{completion::Prompt, providers::openai};
///
/// let openai = openai::Client::from_env();
///
/// let comedian_agent = openai
/// .agent("gpt-4o")
/// .preamble("You are a comedian here to entertain the user using humour and jokes.")
/// .temperature(0.9)
/// .build();
///
/// let response = comedian_agent.prompt("Entertain me!")
/// .await
/// .expect("Failed to prompt the agent");
/// ```
pub struct Agent<M: CompletionModel> {
/// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
model: M,
/// System prompt
preamble: String,
/// Context documents always available to the agent
static_context: Vec<Document>,
/// Tools that are always available to the agent (identified by their name)
static_tools: Vec<String>,
/// Temperature of the model
temperature: Option<f64>,
/// Maximum number of tokens for the completion
max_tokens: Option<u64>,
/// Additional parameters to be passed to the model
additional_params: Option<serde_json::Value>,
/// List of vector store, with the sample number
dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Dynamic tools
dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Actual tool implementations
pub tools: ToolSet,
}
impl<M: CompletionModel> Completion<M> for Agent<M> {
async fn completion(
&self,
prompt: impl Into<Message> + Send,
chat_history: Vec<Message>,
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
let prompt = prompt.into();
let rag_text = prompt.rag_text().clone();
let completion_request = self
.model
.completion_request(prompt)
.preamble(self.preamble.clone())
.messages(chat_history)
.temperature_opt(self.temperature)
.max_tokens_opt(self.max_tokens)
.additional_params_opt(self.additional_params.clone())
.documents(self.static_context.clone());
let agent = match &rag_text {
Some(text) => {
let dynamic_context = stream::iter(self.dynamic_context.iter())
.then(|(num_sample, index)| async {
Ok::<_, VectorStoreError>(
index
.top_n(text, *num_sample)
.await?
.into_iter()
.map(|(_, id, doc)| {
// Pretty print the document if possible for better readability
let text = serde_json::to_string_pretty(&doc)
.unwrap_or_else(|_| doc.to_string());
Document {
id,
text,
additional_props: HashMap::new(),
}
})
.collect::<Vec<_>>(),
)
})
.try_fold(vec![], |mut acc, docs| async {
acc.extend(docs);
Ok(acc)
})
.await
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;
let dynamic_tools = stream::iter(self.dynamic_tools.iter())
.then(|(num_sample, index)| async {
Ok::<_, VectorStoreError>(
index
.top_n_ids(text, *num_sample)
.await?
.into_iter()
.map(|(_, id)| id)
.collect::<Vec<_>>(),
)
})
.try_fold(vec![], |mut acc, docs| async {
for doc in docs {
if let Some(tool) = self.tools.get(&doc) {
acc.push(tool.definition(text.into()).await)
} else {
tracing::warn!("Tool implementation not found in toolset: {}", doc);
}
}
Ok(acc)
})
.await
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;
let static_tools = stream::iter(self.static_tools.iter())
.filter_map(|toolname| async move {
if let Some(tool) = self.tools.get(toolname) {
Some(tool.definition(text.into()).await)
} else {
tracing::warn!(
"Tool implementation not found in toolset: {}",
toolname
);
None
}
})
.collect::<Vec<_>>()
.await;
completion_request
.documents(dynamic_context)
.tools([static_tools.clone(), dynamic_tools].concat())
}
None => {
let static_tools = stream::iter(self.static_tools.iter())
.filter_map(|toolname| async move {
if let Some(tool) = self.tools.get(toolname) {
// TODO: tool definitions should likely take an `Option<String>`
Some(tool.definition("".into()).await)
} else {
tracing::warn!(
"Tool implementation not found in toolset: {}",
toolname
);
None
}
})
.collect::<Vec<_>>()
.await;
completion_request.tools(static_tools)
}
};
Ok(agent)
}
}
impl<M: CompletionModel> Prompt for Agent<M> {
async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
self.chat(prompt, vec![]).await
}
}
impl<M: CompletionModel> Prompt for &Agent<M> {
async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
self.chat(prompt, vec![]).await
}
}
impl<M: CompletionModel> Chat for Agent<M> {
async fn chat(
&self,
prompt: impl Into<Message> + Send,
chat_history: Vec<Message>,
) -> Result<String, PromptError> {
let resp = self.completion(prompt, chat_history).await?.send().await?;
// TODO: consider returning a `Message` instead of `String` for parallel responses / tool calls
match resp.choice.first() {
AssistantContent::Text(text) => Ok(text.text.clone()),
AssistantContent::ToolCall(tool_call) => Ok(self
.tools
.call(
&tool_call.function.name,
tool_call.function.arguments.to_string(),
)
.await?),
}
}
}
/// A builder for creating an agent
///
/// # Example
/// ```
/// use rig::{providers::openai, agent::AgentBuilder};
///
/// let openai = openai::Client::from_env();
///
/// let gpt4o = openai.completion_model("gpt-4o");
///
/// // Configure the agent
/// let agent = AgentBuilder::new(model)
/// .preamble("System prompt")
/// .context("Context document 1")
/// .context("Context document 2")
/// .tool(tool1)
/// .tool(tool2)
/// .temperature(0.8)
/// .additional_params(json!({"foo": "bar"}))
/// .build();
/// ```
pub struct AgentBuilder<M: CompletionModel> {
/// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
model: M,
/// System prompt
preamble: Option<String>,
/// Context documents always available to the agent
static_context: Vec<Document>,
/// Tools that are always available to the agent (by name)
static_tools: Vec<String>,
/// Additional parameters to be passed to the model
additional_params: Option<serde_json::Value>,
/// Maximum number of tokens for the completion
max_tokens: Option<u64>,
/// List of vector store, with the sample number
dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Dynamic tools
dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Temperature of the model
temperature: Option<f64>,
/// Actual tool implementations
tools: ToolSet,
}
impl<M: CompletionModel> AgentBuilder<M> {
pub fn new(model: M) -> Self {
Self {
model,
preamble: None,
static_context: vec![],
static_tools: vec![],
temperature: None,
max_tokens: None,
additional_params: None,
dynamic_context: vec![],
dynamic_tools: vec![],
tools: ToolSet::default(),
}
}
/// Set the system prompt
pub fn preamble(mut self, preamble: &str) -> Self {
self.preamble = Some(preamble.into());
self
}
/// Append to the preamble of the agent
pub fn append_preamble(mut self, doc: &str) -> Self {
self.preamble = Some(format!(
"{}\n{}",
self.preamble.unwrap_or_else(|| "".into()),
doc
));
self
}
/// Add a static context document to the agent
pub fn context(mut self, doc: &str) -> Self {
self.static_context.push(Document {
id: format!("static_doc_{}", self.static_context.len()),
text: doc.into(),
additional_props: HashMap::new(),
});
self
}
/// Add a static tool to the agent
pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
let toolname = tool.name();
self.tools.add_tool(tool);
self.static_tools.push(toolname);
self
}
// Add an MCP tool to the agent
#[cfg(feature = "mcp")]
pub fn mcp_tool<T: mcp_core::transport::Transport>(
mut self,
tool: mcp_core::types::Tool,
client: mcp_core::client::Client<T>,
) -> Self {
let toolname = tool.name.clone();
self.tools.add_tool(McpTool::from_mcp_server(tool, client));
self.static_tools.push(toolname);
self
}
/// Add some dynamic context to the agent. On each prompt, `sample` documents from the
/// dynamic context will be inserted in the request.
pub fn dynamic_context(
mut self,
sample: usize,
dynamic_context: impl VectorStoreIndexDyn + 'static,
) -> Self {
self.dynamic_context
.push((sample, Box::new(dynamic_context)));
self
}
/// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
/// dynamic toolset will be inserted in the request.
pub fn dynamic_tools(
mut self,
sample: usize,
dynamic_tools: impl VectorStoreIndexDyn + 'static,
toolset: ToolSet,
) -> Self {
self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
self.tools.add_tools(toolset);
self
}
/// Set the temperature of the model
pub fn temperature(mut self, temperature: f64) -> Self {
self.temperature = Some(temperature);
self
}
/// Set the maximum number of tokens for the completion
pub fn max_tokens(mut self, max_tokens: u64) -> Self {
self.max_tokens = Some(max_tokens);
self
}
/// Set additional parameters to be passed to the model
pub fn additional_params(mut self, params: serde_json::Value) -> Self {
self.additional_params = Some(params);
self
}
/// Build the agent
pub fn build(self) -> Agent<M> {
Agent {
model: self.model,
preamble: self.preamble.unwrap_or_default(),
static_context: self.static_context,
static_tools: self.static_tools,
temperature: self.temperature,
max_tokens: self.max_tokens,
additional_params: self.additional_params,
dynamic_context: self.dynamic_context,
dynamic_tools: self.dynamic_tools,
tools: self.tools,
}
}
}
impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
async fn stream_completion(
&self,
prompt: impl Into<Message> + Send,
chat_history: Vec<Message>,
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
// Reuse the existing completion implementation to build the request
// This ensures streaming and non-streaming use the same request building logic
self.completion(prompt, chat_history).await
}
}
impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
self.stream_chat(prompt, vec![]).await
}
}
impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
async fn stream_chat(
&self,
prompt: &str,
chat_history: Vec<Message>,
) -> Result<StreamingResult, CompletionError> {
self.stream_completion(prompt, chat_history)
.await?
.stream()
.await
}
}

View File

@ -0,0 +1,179 @@
use std::collections::HashMap;
use crate::{
completion::{CompletionModel, Document},
tool::{Tool, ToolSet},
vector_store::VectorStoreIndexDyn,
};
#[cfg(feature = "mcp")]
use crate::tool::McpTool;
use super::Agent;
/// A builder for creating an agent
///
/// # Example
/// ```
/// use rig::{providers::openai, agent::AgentBuilder};
///
/// let openai = openai::Client::from_env();
///
/// let gpt4o = openai.completion_model("gpt-4o");
///
/// // Configure the agent
/// let agent = AgentBuilder::new(model)
/// .preamble("System prompt")
/// .context("Context document 1")
/// .context("Context document 2")
/// .tool(tool1)
/// .tool(tool2)
/// .temperature(0.8)
/// .additional_params(json!({"foo": "bar"}))
/// .build();
/// ```
pub struct AgentBuilder<M: CompletionModel> {
/// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
model: M,
/// System prompt
preamble: Option<String>,
/// Context documents always available to the agent
static_context: Vec<Document>,
/// Tools that are always available to the agent (by name)
static_tools: Vec<String>,
/// Additional parameters to be passed to the model
additional_params: Option<serde_json::Value>,
/// Maximum number of tokens for the completion
max_tokens: Option<u64>,
/// List of vector stores, each with its sample size
dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Dynamic tools
dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Temperature of the model
temperature: Option<f64>,
/// Actual tool implementations
tools: ToolSet,
}
impl<M: CompletionModel> AgentBuilder<M> {
pub fn new(model: M) -> Self {
Self {
model,
preamble: None,
static_context: vec![],
static_tools: vec![],
temperature: None,
max_tokens: None,
additional_params: None,
dynamic_context: vec![],
dynamic_tools: vec![],
tools: ToolSet::default(),
}
}
/// Set the system prompt
pub fn preamble(mut self, preamble: &str) -> Self {
self.preamble = Some(preamble.into());
self
}
/// Append to the preamble of the agent
pub fn append_preamble(mut self, doc: &str) -> Self {
self.preamble = Some(format!(
"{}\n{}",
self.preamble.unwrap_or_else(|| "".into()),
doc
));
self
}
/// Add a static context document to the agent
pub fn context(mut self, doc: &str) -> Self {
self.static_context.push(Document {
id: format!("static_doc_{}", self.static_context.len()),
text: doc.into(),
additional_props: HashMap::new(),
});
self
}
/// Add a static tool to the agent
pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
let toolname = tool.name();
self.tools.add_tool(tool);
self.static_tools.push(toolname);
self
}
// Add an MCP tool to the agent
#[cfg(feature = "mcp")]
pub fn mcp_tool<T: mcp_core::transport::Transport>(
mut self,
tool: mcp_core::types::Tool,
client: mcp_core::client::Client<T>,
) -> Self {
let toolname = tool.name.clone();
self.tools.add_tool(McpTool::from_mcp_server(tool, client));
self.static_tools.push(toolname);
self
}
/// Add some dynamic context to the agent. On each prompt, `sample` documents from the
/// dynamic context will be inserted in the request.
pub fn dynamic_context(
mut self,
sample: usize,
dynamic_context: impl VectorStoreIndexDyn + 'static,
) -> Self {
self.dynamic_context
.push((sample, Box::new(dynamic_context)));
self
}
/// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
/// dynamic toolset will be inserted in the request.
pub fn dynamic_tools(
mut self,
sample: usize,
dynamic_tools: impl VectorStoreIndexDyn + 'static,
toolset: ToolSet,
) -> Self {
self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
self.tools.add_tools(toolset);
self
}
/// Set the temperature of the model
pub fn temperature(mut self, temperature: f64) -> Self {
self.temperature = Some(temperature);
self
}
/// Set the maximum number of tokens for the completion
pub fn max_tokens(mut self, max_tokens: u64) -> Self {
self.max_tokens = Some(max_tokens);
self
}
/// Set additional parameters to be passed to the model
pub fn additional_params(mut self, params: serde_json::Value) -> Self {
self.additional_params = Some(params);
self
}
/// Build the agent
pub fn build(self) -> Agent<M> {
Agent {
model: self.model,
preamble: self.preamble.unwrap_or_default(),
static_context: self.static_context,
static_tools: self.static_tools,
temperature: self.temperature,
max_tokens: self.max_tokens,
additional_params: self.additional_params,
dynamic_context: self.dynamic_context,
dynamic_tools: self.dynamic_tools,
tools: self.tools,
}
}
}

View File

@ -0,0 +1,244 @@
use std::collections::HashMap;
use futures::{stream, StreamExt, TryStreamExt};
use crate::{
completion::{
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
Message, Prompt, PromptError,
},
streaming::{
StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
StreamingResult,
},
tool::ToolSet,
vector_store::VectorStoreError,
};
use super::prompt_request::PromptRequest;
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
/// All context documents and tools are always provided to the agent when prompted.
///
/// # Example
/// ```
/// use rig::{completion::Prompt, providers::openai};
///
/// let openai = openai::Client::from_env();
///
/// let comedian_agent = openai
/// .agent("gpt-4o")
/// .preamble("You are a comedian here to entertain the user using humour and jokes.")
/// .temperature(0.9)
/// .build();
///
/// let response = comedian_agent.prompt("Entertain me!")
/// .await
/// .expect("Failed to prompt the agent");
/// ```
pub struct Agent<M: CompletionModel> {
/// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
pub model: M,
/// System prompt
pub preamble: String,
/// Context documents always available to the agent
pub static_context: Vec<Document>,
/// Tools that are always available to the agent (identified by their name)
pub static_tools: Vec<String>,
/// Temperature of the model
pub temperature: Option<f64>,
/// Maximum number of tokens for the completion
pub max_tokens: Option<u64>,
/// Additional parameters to be passed to the model
pub additional_params: Option<serde_json::Value>,
/// List of vector stores, each with its sample size
pub dynamic_context: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
/// Dynamic tools
pub dynamic_tools: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
/// Actual tool implementations
pub tools: ToolSet,
}
impl<M: CompletionModel> Completion<M> for Agent<M> {
async fn completion(
&self,
prompt: impl Into<Message> + Send,
chat_history: Vec<Message>,
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
let prompt = prompt.into();
let rag_text = prompt.rag_text().clone();
let completion_request = self
.model
.completion_request(prompt)
.preamble(self.preamble.clone())
.messages(chat_history)
.temperature_opt(self.temperature)
.max_tokens_opt(self.max_tokens)
.additional_params_opt(self.additional_params.clone())
.documents(self.static_context.clone());
let agent = match &rag_text {
Some(text) => {
let dynamic_context = stream::iter(self.dynamic_context.iter())
.then(|(num_sample, index)| async {
Ok::<_, VectorStoreError>(
index
.top_n(text, *num_sample)
.await?
.into_iter()
.map(|(_, id, doc)| {
// Pretty print the document if possible for better readability
let text = serde_json::to_string_pretty(&doc)
.unwrap_or_else(|_| doc.to_string());
Document {
id,
text,
additional_props: HashMap::new(),
}
})
.collect::<Vec<_>>(),
)
})
.try_fold(vec![], |mut acc, docs| async {
acc.extend(docs);
Ok(acc)
})
.await
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;
let dynamic_tools = stream::iter(self.dynamic_tools.iter())
.then(|(num_sample, index)| async {
Ok::<_, VectorStoreError>(
index
.top_n_ids(text, *num_sample)
.await?
.into_iter()
.map(|(_, id)| id)
.collect::<Vec<_>>(),
)
})
.try_fold(vec![], |mut acc, docs| async {
for doc in docs {
if let Some(tool) = self.tools.get(&doc) {
acc.push(tool.definition(text.into()).await)
} else {
tracing::warn!("Tool implementation not found in toolset: {}", doc);
}
}
Ok(acc)
})
.await
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;
let static_tools = stream::iter(self.static_tools.iter())
.filter_map(|toolname| async move {
if let Some(tool) = self.tools.get(toolname) {
Some(tool.definition(text.into()).await)
} else {
tracing::warn!(
"Tool implementation not found in toolset: {}",
toolname
);
None
}
})
.collect::<Vec<_>>()
.await;
completion_request
.documents(dynamic_context)
.tools([static_tools.clone(), dynamic_tools].concat())
}
None => {
let static_tools = stream::iter(self.static_tools.iter())
.filter_map(|toolname| async move {
if let Some(tool) = self.tools.get(toolname) {
// TODO: tool definitions should likely take an `Option<String>`
Some(tool.definition("".into()).await)
} else {
tracing::warn!(
"Tool implementation not found in toolset: {}",
toolname
);
None
}
})
.collect::<Vec<_>>()
.await;
completion_request.tools(static_tools)
}
};
Ok(agent)
}
}
// These impls refine the opaque `Prompt` trait so that calling `.prompt` on an `Agent` (or `&Agent`)
// resolves to the more specific `PromptRequest` at the call-site, keeping the builder-style usage fluent.
//
// References:
// - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait)
#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for Agent<M> {
fn prompt(&self, prompt: impl Into<Message> + Send) -> PromptRequest<M> {
PromptRequest::new(self, prompt)
}
}
#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for &Agent<M> {
fn prompt(&self, prompt: impl Into<Message> + Send) -> PromptRequest<M> {
PromptRequest::new(*self, prompt)
}
}
#[allow(refining_impl_trait)]
impl<M: CompletionModel> Chat for Agent<M> {
async fn chat(
&self,
prompt: impl Into<Message> + Send,
chat_history: Vec<Message>,
) -> Result<String, PromptError> {
let mut cloned_history = chat_history.clone();
PromptRequest::new(self, prompt)
.with_history(&mut cloned_history)
.await
}
}
impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
async fn stream_completion(
&self,
prompt: impl Into<Message> + Send,
chat_history: Vec<Message>,
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
// Reuse the existing completion implementation to build the request
// This ensures streaming and non-streaming use the same request building logic
self.completion(prompt, chat_history).await
}
}
impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
self.stream_chat(prompt, vec![]).await
}
}
impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
async fn stream_chat(
&self,
prompt: &str,
chat_history: Vec<Message>,
) -> Result<StreamingResult, CompletionError> {
self.stream_completion(prompt, chat_history)
.await?
.stream()
.await
}
}
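The `Chat` impl above keeps the old call shape working by cloning the supplied history into a `PromptRequest`. A brief sketch contrasting the two call styles (both assume an already-built `Agent`):

```rust
use rig::{
    agent::Agent,
    completion::{Chat, CompletionModel, Prompt, PromptError},
    message::Message,
};

// Old style: history is passed by value; the caller's Vec is left untouched.
async fn via_chat<M: CompletionModel>(
    agent: &Agent<M>,
    history: Vec<Message>,
) -> Result<String, PromptError> {
    agent.chat("Entertain me!", history).await
}

// New style: history is borrowed mutably and updated in place as the turn runs.
async fn via_prompt<M: CompletionModel>(
    agent: &Agent<M>,
    history: &mut Vec<Message>,
) -> Result<String, PromptError> {
    agent.prompt("Entertain me!").with_history(history).await
}
```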

rig-core/src/agent/mod.rs (new file, 116 lines)
View File

@ -0,0 +1,116 @@
//! This module contains the implementation of the [Agent] struct and its builder.
//!
//! The [Agent] struct represents an LLM agent, which combines an LLM model with a preamble (system prompt),
//! a set of context documents, and a set of tools. Note: both context documents and tools can be either
//! static (i.e.: they are always provided) or dynamic (i.e.: they are RAGged at prompt-time).
//!
//! The [Agent] struct is highly configurable, allowing the user to define anything from
//! a simple bot with a specific system prompt to a complex RAG system with a set of dynamic
//! context documents and tools.
//!
//! The [Agent] struct implements the [crate::completion::Completion] and [crate::completion::Prompt] traits,
//! allowing it to be used for generating completion responses and prompts. The [Agent] struct also
//! implements the [crate::completion::Chat] trait, which allows it to be used for generating chat completions.
//!
//! The [AgentBuilder] implements the builder pattern for creating instances of [Agent].
//! It allows configuring the model, preamble, context documents, tools, temperature, and additional parameters
//! before building the agent.
//!
//! # Example
//! ```rust
//! use rig::{
//! completion::{Chat, Completion, Prompt},
//! providers::openai,
//! };
//!
//! let openai = openai::Client::from_env();
//!
//! // Configure the agent
//! let agent = openai.agent("gpt-4o")
//! .preamble("System prompt")
//! .context("Context document 1")
//! .context("Context document 2")
//! .tool(tool1)
//! .tool(tool2)
//! .temperature(0.8)
//! .additional_params(json!({"foo": "bar"}))
//! .build();
//!
//! // Use the agent for completions and prompts
//! // Generate a chat completion response from a prompt and chat history
//! let chat_response = agent.chat("Prompt", chat_history)
//! .await
//! .expect("Failed to chat with Agent");
//!
//! // Generate a prompt completion response from a simple prompt
//! let chat_response = agent.prompt("Prompt")
//! .await
//! .expect("Failed to prompt the Agent");
//!
//! // Generate a completion request builder from a prompt and chat history. The builder
//! // will contain the agent's configuration (i.e.: preamble, context documents, tools,
//! // model parameters, etc.), but these can be overwritten.
//! let completion_req_builder = agent.completion("Prompt", chat_history)
//! .await
//! .expect("Failed to create completion request builder");
//!
//! let response = completion_req_builder
//! .temperature(0.9) // Overwrite the agent's temperature
//! .send()
//! .await
//! .expect("Failed to send completion request");
//! ```
//!
//! RAG Agent example
//! ```rust
//! use rig::{
//! completion::Prompt,
//! embeddings::EmbeddingsBuilder,
//! providers::openai,
//! vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
//! };
//!
//! // Initialize OpenAI client
//! let openai = openai::Client::from_env();
//!
//! // Initialize OpenAI embedding model
//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002);
//!
//! // Create vector store, compute embeddings and load them in the store
//! let mut vector_store = InMemoryVectorStore::default();
//!
//! let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
//! .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
//! .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
//! .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
//! .build()
//! .await
//! .expect("Failed to build embeddings");
//!
//! vector_store.add_documents(embeddings)
//! .await
//! .expect("Failed to add documents");
//!
//! // Create vector store index
//! let index = vector_store.index(embedding_model);
//!
//! let agent = openai.agent(openai::GPT_4O)
//! .preamble("
//! You are a dictionary assistant here to assist the user in understanding the meaning of words.
//! You will find additional non-standard word definitions that could be useful below.
//! ")
//! .dynamic_context(1, index)
//! .build();
//!
//! // Prompt the agent and print the response
//! let response = agent.prompt("What does \"glarb-glarb\" mean?").await
//! .expect("Failed to prompt the agent");
//! ```
mod builder;
mod completion;
mod prompt_request;
pub use builder::AgentBuilder;
pub use completion::Agent;
pub use prompt_request::PromptRequest;

View File

@ -0,0 +1,173 @@
use std::future::IntoFuture;
use futures::{future::BoxFuture, stream, FutureExt, StreamExt};
use crate::{
completion::{Completion, CompletionError, CompletionModel, Message, PromptError},
message::{AssistantContent, UserContent},
tool::ToolSetError,
OneOrMany,
};
use super::Agent;
/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
pub struct PromptRequest<'a, M: CompletionModel> {
/// The prompt message to send to the model
prompt: Message,
/// Optional chat history to include with the prompt
/// Note: chat history needs to outlive the agent as it might be used with other agents
chat_history: Option<&'a mut Vec<Message>>,
/// Maximum depth for multi-turn conversations (0 means no multi-turn)
max_depth: usize,
/// The agent to use for execution
agent: &'a Agent<M>,
}
impl<'a, M: CompletionModel> PromptRequest<'a, M> {
/// Create a new PromptRequest with the given prompt and model
pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
Self {
prompt: prompt.into(),
chat_history: None,
max_depth: 0,
agent,
}
}
}
impl<'a, M: CompletionModel> PromptRequest<'a, M> {
/// Set the maximum depth for multi-turn conversations
pub fn multi_turn(self, depth: usize) -> PromptRequest<'a, M> {
PromptRequest {
prompt: self.prompt,
chat_history: self.chat_history,
max_depth: depth,
agent: self.agent,
}
}
/// Add chat history to the prompt request
pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, M> {
PromptRequest {
prompt: self.prompt,
chat_history: Some(history),
max_depth: self.max_depth,
agent: self.agent,
}
}
}
/// Due to: [RFC 2515](https://github.com/rust-lang/rust/issues/63063), we have to use a `BoxFuture`
/// for the `IntoFuture` implementation. In the future, we should be able to use `impl Future<...>`
/// directly via the associated type.
impl<'a, M: CompletionModel> IntoFuture for PromptRequest<'a, M> {
type Output = Result<String, PromptError>;
type IntoFuture = BoxFuture<'a, Self::Output>; // This future should not outlive the agent
fn into_future(self) -> Self::IntoFuture {
self.send().boxed()
}
}
impl<M: CompletionModel> PromptRequest<'_, M> {
async fn send(self) -> Result<String, PromptError> {
let agent = self.agent;
let mut prompt = self.prompt;
let chat_history = if let Some(history) = self.chat_history {
history
} else {
&mut Vec::new()
};
let mut current_max_depth = 0;
// We need to do at least 2 loops for 1 round trip (the user expects a normal message)
while current_max_depth <= self.max_depth + 1 {
current_max_depth += 1;
if self.max_depth > 1 {
tracing::info!(
"Current conversation depth: {}/{}",
current_max_depth,
self.max_depth
);
}
let resp = agent
.completion(prompt.clone(), chat_history.to_vec())
.await?
.send()
.await?;
chat_history.push(prompt);
let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
.choice
.iter()
.partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));
chat_history.push(Message::Assistant {
content: resp.choice.clone(),
});
if tool_calls.is_empty() {
let merged_texts = texts
.into_iter()
.filter_map(|content| {
if let AssistantContent::Text(text) = content {
Some(text.text.clone())
} else {
None
}
})
.collect::<Vec<_>>()
.join("\n");
if self.max_depth > 1 {
tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
}
// If there are no tool calls, depth is not relevant; we can just return the merged text.
return Ok(merged_texts);
}
let tool_content = stream::iter(tool_calls)
.then(async |choice| {
if let AssistantContent::ToolCall(tool_call) = choice {
let output = agent
.tools
.call(
&tool_call.function.name,
tool_call.function.arguments.to_string(),
)
.await?;
Ok(UserContent::tool_result(
tool_call.id.clone(),
OneOrMany::one(output.into()),
))
} else {
unreachable!(
"This should never happen as we already filtered for `ToolCall`"
)
}
})
.collect::<Vec<Result<UserContent, ToolSetError>>>()
.await
.into_iter()
.collect::<Result<Vec<_>, _>>()
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;
prompt = Message::User {
content: OneOrMany::many(tool_content).expect("There is at least one tool call"),
};
}
// If we reach here, the loop ended while tool calls were still pending, so surface a MaxDepthError.
Err(PromptError::MaxDepthError {
max_depth: self.max_depth,
chat_history: chat_history.clone(),
prompt,
})
}
}
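Because `PromptRequest` implements `IntoFuture`, the request is configured fluently and then awaited directly, and the borrowed history is left holding the full exchange, including any intermediate tool-call and tool-result messages. A minimal sketch (prompt text and turn limit are arbitrary):

```rust
use rig::{
    agent::Agent,
    completion::{CompletionModel, Prompt},
    message::Message,
};

async fn traced_prompt<M: CompletionModel>(agent: &Agent<M>) -> anyhow::Result<()> {
    let mut history: Vec<Message> = Vec::new();

    // `.await` desugars through `IntoFuture`, driving `send()` above.
    let answer = agent
        .prompt("Calculate (3 + 5) / 9 = ?")
        .multi_turn(5)
        .with_history(&mut history)
        .await?;

    println!("answer: {answer}");
    // The prompt, tool traffic, and final assistant reply were pushed into `history`.
    println!("messages recorded: {}", history.len());
    Ok(())
}
```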

View File

@ -229,6 +229,16 @@ impl Message {
content: OneOrMany::one(AssistantContent::text(text)),
}
}
// Helper constructor to make creating tool result messages easier.
pub fn tool_result(id: impl Into<String>, content: impl Into<String>) -> Self {
Message::User {
content: OneOrMany::one(UserContent::ToolResult(ToolResult {
id: id.into(),
content: OneOrMany::one(ToolResultContent::text(content)),
})),
}
}
}
impl UserContent {
@ -467,6 +477,12 @@ impl From<String> for Text {
}
}
impl From<&String> for Text {
fn from(text: &String) -> Self {
text.to_owned().into()
}
}
impl From<&str> for Text {
fn from(text: &str) -> Self {
text.to_owned().into()
@ -497,6 +513,14 @@ impl From<&str> for Message {
}
}
impl From<&String> for Message {
fn from(text: &String) -> Self {
Message::User {
content: OneOrMany::one(UserContent::Text(text.into())),
}
}
}
impl From<Text> for Message {
fn from(text: Text) -> Self {
Message::User {

View File

@ -75,7 +75,7 @@ use crate::{
tool::ToolSetError,
};
use super::message::AssistantContent;
use super::message::{AssistantContent, ContentFormat, DocumentMediaType};
// Errors
#[derive(Debug, Error)]
@ -108,6 +108,13 @@ pub enum PromptError {
#[error("ToolCallError: {0}")]
ToolError(#[from] ToolSetError),
#[error("MaxDepthError: (reached limit: {max_depth})")]
MaxDepthError {
max_depth: usize,
chat_history: Vec<Message>,
prompt: Message,
},
}
#[derive(Clone, Debug, Deserialize, Serialize)]
@ -163,7 +170,7 @@ pub trait Prompt: Send + Sync {
fn prompt(
&self,
prompt: impl Into<Message> + Send,
) -> impl std::future::Future<Output = Result<String, PromptError>> + Send;
) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: Send>;
}
/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
@ -180,7 +187,7 @@ pub trait Chat: Send + Sync {
&self,
prompt: impl Into<Message> + Send,
chat_history: Vec<Message>,
) -> impl std::future::Future<Output = Result<String, PromptError>> + Send;
) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: Send>;
}
/// Trait defining a low-level LLM completion interface
@ -236,12 +243,11 @@ pub trait CompletionModel: Clone + Send + Sync {
/// Struct representing a general completion request that can be sent to a completion model provider.
pub struct CompletionRequest {
/// The prompt to be sent to the completion model provider
pub prompt: Message,
/// The preamble to be sent to the completion model provider
pub preamble: Option<String>,
/// The chat history to be sent to the completion model provider
pub chat_history: Vec<Message>,
/// The very last message will always be the prompt (hence why there is *always* one)
pub chat_history: OneOrMany<Message>,
/// The documents to be sent to the completion model provider
pub documents: Vec<Document>,
/// The tools to be sent to the completion model provider
@ -255,23 +261,33 @@ pub struct CompletionRequest {
}
impl CompletionRequest {
pub fn prompt_with_context(&self) -> Message {
let mut new_prompt = self.prompt.clone();
if let Message::User { ref mut content } = new_prompt {
if !self.documents.is_empty() {
let attachments = self
.documents
.iter()
.map(|doc| doc.to_string())
.collect::<Vec<_>>()
.join("");
let formatted_content = format!("<attachments>\n{}</attachments>", attachments);
let mut new_content = vec![UserContent::text(formatted_content)];
new_content.extend(content.clone());
*content = OneOrMany::many(new_content).expect("This has more than 1 item");
}
/// Returns documents normalized into a message (if any).
/// Most providers do not accept documents directly as input, so they need to be converted into a
/// `Message` so that they can be incorporated into the `chat_history` as a user message.
pub fn normalized_documents(&self) -> Option<Message> {
if self.documents.is_empty() {
return None;
}
new_prompt
// Most providers will convert documents into text unless they can handle document messages directly.
// We use `UserContent::document` for those who handle it directly!
let messages = self
.documents
.iter()
.map(|doc| {
UserContent::document(
doc.to_string(),
// In the future, we can customize `Document` to pass these extra types through.
// Most providers ditch these but they might want to use them.
Some(ContentFormat::String),
Some(DocumentMediaType::TXT),
)
})
.collect::<Vec<_>>();
Some(Message::User {
content: OneOrMany::many(messages).expect("There will be at least one document"),
})
}
}
@ -446,10 +462,12 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
/// Builds the completion request.
pub fn build(self) -> CompletionRequest {
let chat_history = OneOrMany::many([self.chat_history, vec![self.prompt]].concat())
.expect("There will always be atleast the prompt");
CompletionRequest {
prompt: self.prompt,
preamble: self.preamble,
chat_history: self.chat_history,
chat_history,
documents: self.documents,
tools: self.tools,
temperature: self.temperature,
@ -475,7 +493,6 @@ impl<M: StreamingCompletionModel> CompletionRequestBuilder<M> {
#[cfg(test)]
mod tests {
use crate::OneOrMany;
use super::*;
@ -513,7 +530,7 @@ mod tests {
}
#[test]
fn test_prompt_with_context_with_documents() {
fn test_normalize_documents_with_documents() {
let doc1 = Document {
id: "doc1".to_string(),
text: "Document 1 text.".to_string(),
@ -527,9 +544,8 @@ mod tests {
};
let request = CompletionRequest {
prompt: "What is the capital of France?".into(),
preamble: None,
chat_history: Vec::new(),
chat_history: OneOrMany::one("What is the capital of France?".into()),
documents: vec![doc1, doc2],
tools: Vec::new(),
temperature: None,
@ -539,19 +555,35 @@ mod tests {
let expected = Message::User {
content: OneOrMany::many(vec![
UserContent::text(concat!(
"<attachments>\n",
"<file id: doc1>\nDocument 1 text.\n</file>\n",
"<file id: doc2>\nDocument 2 text.\n</file>\n",
"</attachments>"
)),
UserContent::text("What is the capital of France?"),
UserContent::document(
"<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
Some(ContentFormat::String),
Some(DocumentMediaType::TXT),
),
UserContent::document(
"<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
Some(ContentFormat::String),
Some(DocumentMediaType::TXT),
),
])
.expect("This has more than 1 item"),
.expect("There will be at least one document"),
};
request.prompt_with_context();
assert_eq!(request.normalized_documents(), Some(expected));
}
assert_eq!(request.prompt_with_context(), expected);
#[test]
fn test_normalize_documents_without_documents() {
let request = CompletionRequest {
preamble: None,
chat_history: OneOrMany::one("What is the capital of France?".into()),
documents: Vec::new(),
tools: Vec::new(),
temperature: None,
max_tokens: None,
additional_params: None,
};
assert_eq!(request.normalized_documents(), None);
}
}
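Since `Prompt::prompt` (and `Chat::chat`) now return `impl IntoFuture` instead of a plain `Future`, existing generic call-sites keep working because `.await` desugars through `IntoFuture`. A minimal generic sketch:

```rust
use rig::completion::{Prompt, PromptError};

// Works for any `Prompt` implementor, whether its request resolves to a bare
// future or to a request builder such as `PromptRequest`.
async fn ask<P: Prompt>(model: &P, question: &str) -> Result<String, PromptError> {
    model.prompt(question).await
}
```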

View File

@ -36,10 +36,13 @@ use serde_json::json;
use crate::{
agent::{Agent, AgentBuilder},
completion::{CompletionModel, Prompt, PromptError, ToolDefinition},
completion::{Completion, CompletionError, CompletionModel, ToolDefinition},
message::{AssistantContent, Message, ToolCall, ToolFunction},
tool::Tool,
};
const SUBMIT_TOOL_NAME: &str = "submit";
#[derive(Debug, thiserror::Error)]
pub enum ExtractionError {
#[error("No data extracted")]
@ -48,8 +51,8 @@ pub enum ExtractionError {
#[error("Failed to deserialize the extracted data: {0}")]
DeserializationError(#[from] serde_json::Error),
#[error("PromptError: {0}")]
PromptError(#[from] PromptError),
#[error("CompletionError: {0}")]
CompletionError(#[from] CompletionError),
}
/// Extractor for structured data from text
@ -62,14 +65,43 @@ impl<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, M: CompletionModel>
where
M: Sync,
{
pub async fn extract(&self, text: &str) -> Result<T, ExtractionError> {
let summary = self.agent.prompt(text).await?;
pub async fn extract(&self, text: impl Into<Message> + Send) -> Result<T, ExtractionError> {
let response = self.agent.completion(text, vec![]).await?.send().await?;
if summary.is_empty() {
return Err(ExtractionError::NoData);
let arguments = response
.choice
.into_iter()
// We filter tool calls to look for submit tool calls
.filter_map(|content| {
if let AssistantContent::ToolCall(ToolCall {
function: ToolFunction { arguments, name },
..
}) = content
{
if name == SUBMIT_TOOL_NAME {
Some(arguments)
} else {
None
}
} else {
None
}
})
.collect::<Vec<_>>();
if arguments.len() > 1 {
tracing::warn!(
"Multiple submit calls detected, using the last one. Providers / agents should only ensure one submit call."
);
}
Ok(serde_json::from_str(&summary)?)
let raw_data = if let Some(arg) = arguments.into_iter().next() {
arg
} else {
return Err(ExtractionError::NoData);
};
Ok(serde_json::from_value(raw_data)?)
}
}
@ -132,7 +164,7 @@ struct SubmitTool<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> {
struct SubmitError;
impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync> Tool for SubmitTool<T> {
const NAME: &'static str = "submit";
const NAME: &'static str = SUBMIT_TOOL_NAME;
type Error = SubmitError;
type Args = T;
type Output = T;
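
The extraction flow is now tool-driven: the agent is prompted, the model is expected to answer with a single submit tool call, and that call's arguments are deserialized into the target type. A usage sketch follows; the model name and the extractor builder call on the client are assumptions about the crate's usual client API, not part of this diff.

// Sketch only: "gpt-4o" and `client.extractor` are placeholders / assumptions.
use rig::providers::openai;
use schemars::JsonSchema;
use serde::Deserialize;

#[derive(Debug, Deserialize, JsonSchema)]
struct Person {
    name: String,
    age: u8,
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = openai::Client::from_env();
    let extractor = client.extractor::<Person>("gpt-4o").build();

    // The model should reply with one `submit` tool call whose arguments
    // deserialize into `Person`; anything else yields ExtractionError::NoData.
    let person = extractor.extract("John Doe is 42 years old.").await?;
    println!("{person:?}");
    Ok(())
}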

View File

@ -1,6 +1,9 @@
use std::future::IntoFuture;
use crate::{
completion::{self, CompletionModel},
extractor::{ExtractionError, Extractor},
message::Message,
vector_store,
};
@ -79,14 +82,14 @@ impl<P, In> Prompt<P, In> {
impl<P, In> Op for Prompt<P, In>
where
P: completion::Prompt,
P: completion::Prompt + Send + Sync,
In: Into<String> + Send + Sync,
{
type Input = In;
type Output = Result<String, completion::PromptError>;
async fn call(&self, input: Self::Input) -> Self::Output {
self.prompt.prompt(input.into()).await
fn call(&self, input: Self::Input) -> impl std::future::Future<Output = Self::Output> + Send {
self.prompt.prompt(input.into()).into_future()
}
}
@ -127,13 +130,13 @@ impl<M, Input, Output> Op for Extract<M, Input, Output>
where
M: CompletionModel,
Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
Input: Into<String> + Send + Sync,
Input: Into<Message> + Send + Sync,
{
type Input = Input;
type Output = Result<Output, ExtractionError>;
async fn call(&self, input: Self::Input) -> Self::Output {
self.extractor.extract(&input.into()).await
self.extractor.extract(input).await
}
}
@ -159,6 +162,7 @@ pub mod tests {
pub struct MockModel;
impl Prompt for MockModel {
#[allow(refining_impl_trait)]
async fn prompt(&self, prompt: impl Into<message::Message>) -> Result<String, PromptError> {
let msg: message::Message = prompt.into();
let prompt = match msg {

View File

@ -553,26 +553,20 @@ impl completion::CompletionModel for CompletionModel {
));
};
let prompt_message: Message = completion_request
.prompt_with_context()
.try_into()
.map_err(|e: MessageError| CompletionError::RequestError(e.into()))?;
let mut full_history = vec![];
if let Some(docs) = completion_request.normalized_documents() {
full_history.push(docs);
}
full_history.extend(completion_request.chat_history);
let mut messages = completion_request
.chat_history
let full_history = full_history
.into_iter()
.map(|message| {
message
.try_into()
.map_err(|e: MessageError| CompletionError::RequestError(e.into()))
})
.map(Message::try_from)
.collect::<Result<Vec<Message>, _>>()?;
messages.push(prompt_message);
let mut request = json!({
"model": self.model,
"messages": messages,
"messages": full_history,
"max_tokens": max_tokens,
"system": completion_request.preamble.unwrap_or("".to_string()),
});
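
The same assembly order recurs in the provider hunks that follow: normalized documents (if any) come first, then the chat history, which now already ends with the prompt as its final user turn. Stripped of the provider-specific message conversion, the shared pattern is roughly this sketch (illustrative only):

// Sketch only: the ordering convention shared by the providers in this commit.
use rig::completion::CompletionRequest;
use rig::message::Message;

fn message_order(request: CompletionRequest) -> Vec<Message> {
    let mut full_history: Vec<Message> = Vec::new();
    if let Some(docs) = request.normalized_documents() {
        // Context documents are injected as a leading user message.
        full_history.push(docs);
    }
    // Chat history follows, with the prompt as its last element.
    full_history.extend(request.chat_history);
    full_history
}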

View File

@ -7,7 +7,6 @@ use super::completion::{CompletionModel, Content, Message, ToolChoice, ToolDefin
use super::decoders::sse::from_response as sse_from_response;
use crate::completion::{CompletionError, CompletionRequest};
use crate::json_utils::merge_inplace;
use crate::message::MessageError;
use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
#[derive(Debug, Deserialize)]
@ -90,26 +89,20 @@ impl StreamingCompletionModel for CompletionModel {
));
};
let prompt_message: Message = completion_request
.prompt_with_context()
.try_into()
.map_err(|e: MessageError| CompletionError::RequestError(e.into()))?;
let mut full_history = vec![];
if let Some(docs) = completion_request.normalized_documents() {
full_history.push(docs);
}
full_history.extend(completion_request.chat_history);
let mut messages = completion_request
.chat_history
let full_history = full_history
.into_iter()
.map(|message| {
message
.try_into()
.map_err(|e: MessageError| CompletionError::RequestError(e.into()))
})
.map(Message::try_from)
.collect::<Result<Vec<Message>, _>>()?;
messages.push(prompt_message);
let mut request = json!({
"model": self.model,
"messages": messages,
"messages": full_history,
"max_tokens": max_tokens,
"system": completion_request.preamble.unwrap_or("".to_string()),
"stream": true,

View File

@ -480,16 +480,14 @@ impl CompletionModel {
&self,
completion_request: CompletionRequest,
) -> Result<serde_json::Value, CompletionError> {
// Add preamble to chat history (if available)
let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
Some(preamble) => vec![openai::Message::system(preamble)],
None => vec![],
};
// Convert prompt to user message
let prompt: Vec<openai::Message> = completion_request.prompt_with_context().try_into()?;
// Convert existing chat history
if let Some(docs) = completion_request.normalized_documents() {
let docs: Vec<openai::Message> = docs.try_into()?;
full_history.extend(docs);
}
let chat_history: Vec<openai::Message> = completion_request
.chat_history
.into_iter()
@ -499,9 +497,7 @@ impl CompletionModel {
.flatten()
.collect();
// Combine all messages into a single history
full_history.extend(chat_history);
full_history.extend(prompt);
let request = if completion_request.tools.is_empty() {
json!({
@ -786,6 +782,7 @@ mod azure_tests {
use crate::completion::CompletionModel;
use crate::embeddings::EmbeddingModel;
use crate::OneOrMany;
#[tokio::test]
#[ignore]
@ -812,8 +809,7 @@ mod azure_tests {
let completion = model
.completion(CompletionRequest {
preamble: Some("You are a helpful assistant.".to_string()),
chat_history: vec![],
prompt: "Hello, world!".into(),
chat_history: OneOrMany::one("Hello!".into()),
documents: vec![],
max_tokens: Some(100),
temperature: Some(0.0),

View File

@ -440,29 +440,34 @@ impl completion::CompletionModel for CompletionModel {
&self,
completion_request: completion::CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
let prompt = completion_request.prompt_with_context();
// Build up the order of messages (context, chat_history)
let mut partial_history = vec![];
if let Some(docs) = completion_request.normalized_documents() {
partial_history.push(docs);
}
partial_history.extend(completion_request.chat_history);
let mut messages: Vec<message::Message> =
if let Some(preamble) = completion_request.preamble {
vec![preamble.into()]
} else {
vec![]
};
// Initialize full history with preamble (or empty if non-existent)
let mut full_history: Vec<Message> = completion_request
.preamble
.map_or_else(Vec::new, |preamble| {
vec![Message::System { content: preamble }]
});
messages.extend(completion_request.chat_history);
messages.push(prompt);
let messages: Vec<Message> = messages
.into_iter()
.map(|msg| msg.try_into())
.collect::<Result<Vec<Vec<_>>, _>>()?
.into_iter()
.flatten()
.collect();
// Convert and extend the rest of the history
full_history.extend(
partial_history
.into_iter()
.map(message::Message::try_into)
.collect::<Result<Vec<Vec<Message>>, _>>()?
.into_iter()
.flatten()
.collect::<Vec<_>>(),
);
let request = json!({
"model": self.model,
"messages": messages,
"messages": full_history,
"documents": completion_request.documents,
"temperature": completion_request.temperature,
"tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),

View File

@ -379,28 +379,28 @@ impl DeepSeekCompletionModel {
&self,
completion_request: CompletionRequest,
) -> Result<serde_json::Value, CompletionError> {
// Add preamble to chat history (if available)
let mut full_history: Vec<Message> = match &completion_request.preamble {
Some(preamble) => vec![Message::system(preamble)],
None => vec![],
};
// Build up the order of messages (context, chat_history, prompt)
let mut partial_history = vec![];
if let Some(docs) = completion_request.normalized_documents() {
partial_history.push(docs);
}
partial_history.extend(completion_request.chat_history);
// Convert prompt to user message
let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
// Initialize full history with preamble (or empty if non-existent)
let mut full_history: Vec<Message> = completion_request
.preamble
.map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
// Convert existing chat history
let chat_history: Vec<Message> = completion_request
.chat_history
.into_iter()
.map(|message| message.try_into())
.collect::<Result<Vec<Vec<Message>>, _>>()?
.into_iter()
.flatten()
.collect();
// Combine all messages into a single history
full_history.extend(chat_history);
full_history.extend(prompt);
// Convert and extend the rest of the history
full_history.extend(
partial_history
.into_iter()
.map(message::Message::try_into)
.collect::<Result<Vec<Vec<Message>>, _>>()?
.into_iter()
.flatten()
.collect::<Vec<_>>(),
);
let request = if completion_request.tools.is_empty() {
json!({

View File

@ -398,6 +398,13 @@ impl CompletionModel {
&self,
completion_request: CompletionRequest,
) -> Result<Value, CompletionError> {
// Build up the order of messages (context, chat_history, prompt)
let mut partial_history = vec![];
if let Some(docs) = completion_request.normalized_documents() {
partial_history.push(docs);
}
partial_history.extend(completion_request.chat_history);
// Add preamble to chat history (if available)
let mut full_history: Vec<Message> = match &completion_request.preamble {
Some(preamble) => vec![Message {
@ -408,19 +415,13 @@ impl CompletionModel {
None => vec![],
};
// Convert prompt to user message
let prompt: Message = completion_request.prompt_with_context().try_into()?;
// Convert existing chat history
let chat_history: Vec<Message> = completion_request
.chat_history
.into_iter()
.map(|message| message.try_into())
.collect::<Result<Vec<Message>, _>>()?;
// Combine all messages into a single history
full_history.extend(chat_history);
full_history.push(prompt);
// Convert and extend the rest of the history
full_history.extend(
partial_history
.into_iter()
.map(message::Message::try_into)
.collect::<Result<Vec<Message>, _>>()?,
);
let request = if completion_request.tools.is_empty() {
json!({

View File

@ -93,11 +93,10 @@ impl completion::CompletionModel for CompletionModel {
}
pub(crate) fn create_request_body(
mut completion_request: CompletionRequest,
completion_request: CompletionRequest,
) -> Result<GenerateContentRequest, CompletionError> {
let mut full_history = Vec::new();
full_history.append(&mut completion_request.chat_history);
full_history.push(completion_request.prompt_with_context());
full_history.extend(completion_request.chat_history);
let additional_params = completion_request
.additional_params

View File

@ -279,28 +279,31 @@ impl CompletionModel {
&self,
completion_request: CompletionRequest,
) -> Result<Value, CompletionError> {
// Add preamble to chat history (if available)
let mut full_history: Vec<Message> = match &completion_request.preamble {
Some(preamble) => vec![Message {
role: "system".to_string(),
content: Some(preamble.to_string()),
}],
None => vec![],
};
// Build up the order of messages (context, chat_history, prompt)
let mut partial_history = vec![];
if let Some(docs) = completion_request.normalized_documents() {
partial_history.push(docs);
}
partial_history.extend(completion_request.chat_history);
// Convert prompt to user message
let prompt: Message = completion_request.prompt_with_context().try_into()?;
// Initialize full history with preamble (or empty if non-existent)
let mut full_history: Vec<Message> =
completion_request
.preamble
.map_or_else(Vec::new, |preamble| {
vec![Message {
role: "system".to_string(),
content: Some(preamble),
}]
});
// Convert existing chat history
let chat_history: Vec<Message> = completion_request
.chat_history
.into_iter()
.map(|message| message.try_into())
.collect::<Result<Vec<Message>, _>>()?;
// Combine all messages into a single history
full_history.extend(chat_history);
full_history.push(prompt);
// Convert and extend the rest of the history
full_history.extend(
partial_history
.into_iter()
.map(message::Message::try_into)
.collect::<Result<Vec<Message>, _>>()?,
);
let request = if completion_request.tools.is_empty() {
json!({

View File

@ -502,8 +502,10 @@ impl CompletionModel {
Some(preamble) => vec![Message::system(preamble)],
None => vec![],
};
let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
if let Some(docs) = completion_request.normalized_documents() {
let docs: Vec<Message> = docs.try_into()?;
full_history.extend(docs);
}
let chat_history: Vec<Message> = completion_request
.chat_history
@ -516,7 +518,6 @@ impl CompletionModel {
.collect();
full_history.extend(chat_history);
full_history.extend(prompt);
let model = self.client.sub_provider.model_identifier(&self.model);

View File

@ -12,6 +12,7 @@
use super::openai::{send_compatible_streaming_request, AssistantContent};
use crate::json_utils::merge_inplace;
use crate::message;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::{
agent::AgentBuilder,
@ -306,28 +307,28 @@ impl CompletionModel {
&self,
completion_request: CompletionRequest,
) -> Result<Value, CompletionError> {
// Add preamble to chat history (if available)
let mut full_history: Vec<Message> = match &completion_request.preamble {
Some(preamble) => vec![Message::system(preamble)],
None => vec![],
};
// Build up the order of messages (context, chat_history, prompt)
let mut partial_history = vec![];
if let Some(docs) = completion_request.normalized_documents() {
partial_history.push(docs);
}
partial_history.extend(completion_request.chat_history);
// Convert prompt to user message
let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
// Initialize full history with preamble (or empty if non-existent)
let mut full_history: Vec<Message> = completion_request
.preamble
.map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
// Convert existing chat history
let chat_history: Vec<Message> = completion_request
.chat_history
.into_iter()
.map(|message| message.try_into())
.collect::<Result<Vec<Vec<Message>>, _>>()?
.into_iter()
.flatten()
.collect();
// Combine all messages into a single history
full_history.extend(chat_history);
full_history.extend(prompt);
// Convert and extend the rest of the history
full_history.extend(
partial_history
.into_iter()
.map(message::Message::try_into)
.collect::<Result<Vec<Vec<Message>>, _>>()?
.into_iter()
.flatten()
.collect::<Vec<_>>(),
);
let request = json!({
"model": self.model,

View File

@ -238,24 +238,25 @@ impl CompletionModel {
}));
}
        // Add prompt
        messages.push(match &completion_request.prompt {
            Message::User { content } => {
                let text = content
                    .iter()
                    .map(|c| match c {
                        UserContent::Text(text) => &text.text,
                        _ => "",
                    })
                    .collect::<Vec<_>>()
                    .join("\n");
                serde_json::json!({
                    "role": "user",
                    "content": text
                })
            }
            _ => unreachable!(),
        });
        // Add docs
        if let Some(Message::User { content }) = completion_request.normalized_documents() {
            let text = content
                .into_iter()
                .filter_map(|doc| match doc {
                    UserContent::Document(doc) => Some(doc.data),
                    UserContent::Text(text) => Some(text.text),
                    // This should always be `Document`
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("\n");
            messages.push(serde_json::json!({
                "role": "user",
                "content": text
            }));
        }
// Add chat history
for msg in completion_request.chat_history {
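
Because this provider only sends plain text here, the normalized_documents branch above flattens every attached document into one user message. Assuming the wrapper format used in the completion tests earlier, two documents would produce roughly the following (ids and texts are made up):

// Sketch only: the flattened document message; the wrapper text is an assumption.
fn main() {
    let text = [
        "<file id: doc1>\nDocument 1 text.\n</file>\n",
        "<file id: doc2>\nDocument 2 text.\n</file>\n",
    ]
    .join("\n");

    let message = serde_json::json!({
        "role": "user",
        "content": text,
    });
    println!("{message}");
}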

View File

@ -10,6 +10,7 @@
//! ```
use crate::json_utils::merge;
use crate::message;
use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::{
@ -141,28 +142,30 @@ impl CompletionModel {
&self,
completion_request: CompletionRequest,
) -> Result<Value, CompletionError> {
// Add preamble to chat history (if available)
let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
Some(preamble) => vec![openai::Message::system(preamble)],
None => vec![],
};
// Build up the order of messages (context, chat_history)
let mut partial_history = vec![];
if let Some(docs) = completion_request.normalized_documents() {
partial_history.push(docs);
}
partial_history.extend(completion_request.chat_history);
// Convert prompt to user message
let prompt: Vec<openai::Message> = completion_request.prompt_with_context().try_into()?;
// Initialize full history with preamble (or empty if non-existent)
let mut full_history: Vec<openai::Message> = completion_request
.preamble
.map_or_else(Vec::new, |preamble| {
vec![openai::Message::system(&preamble)]
});
// Convert existing chat history
let chat_history: Vec<openai::Message> = completion_request
.chat_history
.into_iter()
.map(|message| message.try_into())
.collect::<Result<Vec<Vec<openai::Message>>, _>>()?
.into_iter()
.flatten()
.collect();
// Combine all messages into a single history
full_history.extend(chat_history);
full_history.extend(prompt);
// Convert and extend the rest of the history
full_history.extend(
partial_history
.into_iter()
.map(message::Message::try_into)
.collect::<Result<Vec<Vec<openai::Message>>, _>>()?
.into_iter()
.flatten()
.collect::<Vec<_>>(),
);
let request = if completion_request.tools.is_empty() {
json!({

View File

@ -325,8 +325,27 @@ impl CompletionModel {
&self,
completion_request: CompletionRequest,
) -> Result<Value, CompletionError> {
// Build up the order of messages (context, chat_history)
let mut partial_history = vec![];
if let Some(docs) = completion_request.normalized_documents() {
partial_history.push(docs);
}
partial_history.extend(completion_request.chat_history);
// Initialize full history with preamble (or empty if non-existent)
let mut full_history: Vec<Message> = completion_request
.preamble
.map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
// Convert and extend the rest of the history
full_history.extend(
partial_history
.into_iter()
.map(|msg| msg.try_into())
.collect::<Result<Vec<Message>, _>>()?,
);
// Convert internal prompt into a provider Message
let prompt: Message = completion_request.prompt_with_context().try_into()?;
let options = if let Some(extra) = completion_request.additional_params {
json_utils::merge(
json!({ "temperature": completion_request.temperature }),
@ -336,16 +355,6 @@ impl CompletionModel {
json!({ "temperature": completion_request.temperature })
};
// Chat mode: assemble full conversation history including preamble and chat history
let mut full_history = Vec::new();
if let Some(preamble) = completion_request.preamble {
full_history.push(Message::system(&preamble));
}
for msg in completion_request.chat_history.into_iter() {
full_history.push(Message::try_from(msg)?);
}
full_history.push(prompt);
let mut request_payload = json!({
"model": self.model,
"messages": full_history,

View File

@ -605,28 +605,28 @@ impl CompletionModel {
&self,
completion_request: CompletionRequest,
) -> Result<Value, CompletionError> {
// Add preamble to chat history (if available)
let mut full_history: Vec<Message> = match &completion_request.preamble {
Some(preamble) => vec![Message::system(preamble)],
None => vec![],
};
// Build up the order of messages (context, chat_history)
let mut partial_history = vec![];
if let Some(docs) = completion_request.normalized_documents() {
partial_history.push(docs);
}
partial_history.extend(completion_request.chat_history);
// Convert prompt to user message
let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
// Initialize full history with preamble (or empty if non-existent)
let mut full_history: Vec<Message> = completion_request
.preamble
.map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
// Convert existing chat history
let chat_history: Vec<Message> = completion_request
.chat_history
.into_iter()
.map(|message| message.try_into())
.collect::<Result<Vec<Vec<Message>>, _>>()?
.into_iter()
.flatten()
.collect();
// Combine all messages into a single history
full_history.extend(chat_history);
full_history.extend(prompt);
// Convert and extend the rest of the history
full_history.extend(
partial_history
.into_iter()
.map(message::Message::try_into)
.collect::<Result<Vec<Vec<Message>>, _>>()?
.into_iter()
.flatten()
.collect::<Vec<_>>(),
);
let request = if completion_request.tools.is_empty() {
json!({

View File

@ -269,8 +269,11 @@ impl completion::CompletionModel for CompletionModel {
None => vec![],
};
// Convert prompt to user message
let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
// Gather docs
if let Some(docs) = completion_request.normalized_documents() {
let docs: Vec<Message> = docs.try_into()?;
full_history.extend(docs);
}
// Convert existing chat history
let chat_history: Vec<Message> = completion_request
@ -284,7 +287,6 @@ impl completion::CompletionModel for CompletionModel {
// Combine all messages into a single history
full_history.extend(chat_history);
full_history.extend(prompt);
let request = json!({
"model": self.model,

View File

@ -204,39 +204,36 @@ impl CompletionModel {
&self,
completion_request: CompletionRequest,
) -> Result<Value, CompletionError> {
// Add context documents to current prompt
let prompt_with_context = completion_request.prompt_with_context();
// Add preamble to messages (if available)
let mut messages: Vec<Message> = if let Some(preamble) = completion_request.preamble {
vec![Message {
role: Role::System,
content: preamble,
}]
} else {
vec![]
};
// Add chat history to messages
for message in completion_request.chat_history {
messages.push(
message
.try_into()
.map_err(|e: MessageError| CompletionError::RequestError(e.into()))?,
);
// Build up the order of messages (context, chat_history, prompt)
let mut partial_history = vec![];
if let Some(docs) = completion_request.normalized_documents() {
partial_history.push(docs);
}
partial_history.extend(completion_request.chat_history);
// Add user prompt to messages
messages.push(
prompt_with_context
.try_into()
.map_err(|e: MessageError| CompletionError::RequestError(e.into()))?,
// Initialize full history with preamble (or empty if non-existent)
let mut full_history: Vec<Message> =
completion_request
.preamble
.map_or_else(Vec::new, |preamble| {
vec![Message {
role: Role::System,
content: preamble,
}]
});
// Convert and extend the rest of the history
full_history.extend(
partial_history
.into_iter()
.map(message::Message::try_into)
.collect::<Result<Vec<Message>, _>>()?,
);
// Compose request
let request = json!({
"model": self.model,
"messages": messages,
"messages": full_history,
"temperature": completion_request.temperature,
});

View File

@ -148,7 +148,10 @@ impl CompletionModel {
Some(preamble) => vec![openai::Message::system(preamble)],
None => vec![],
};
let prompt: Vec<openai::Message> = completion_request.prompt_with_context().try_into()?;
if let Some(docs) = completion_request.normalized_documents() {
let docs: Vec<openai::Message> = docs.try_into()?;
full_history.extend(docs);
}
let chat_history: Vec<openai::Message> = completion_request
.chat_history
.into_iter()
@ -157,8 +160,9 @@ impl CompletionModel {
.into_iter()
.flatten()
.collect();
full_history.extend(chat_history);
full_history.extend(prompt);
let mut request = if completion_request.tools.is_empty() {
json!({
"model": self.model,

View File

@ -10,7 +10,6 @@ use crate::{
};
use super::client::{xai_api_types::ApiResponse, Client};
use crate::completion::CompletionRequest;
use serde_json::{json, Value};
use xai_api_types::{CompletionResponse, ToolDefinition};
@ -30,22 +29,13 @@ pub struct CompletionModel {
impl CompletionModel {
pub(crate) fn create_completion_request(
&self,
completion_request: CompletionRequest,
completion_request: completion::CompletionRequest,
) -> Result<Value, CompletionError> {
// Add preamble to chat history (if available)
let mut full_history: Vec<Message> = match &completion_request.preamble {
Some(preamble) => {
if preamble.is_empty() {
vec![]
} else {
vec![Message::system(preamble)]
}
}
None => vec![],
};
// Convert prompt to user message
let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
// Convert documents into user message
let docs: Option<Vec<Message>> = completion_request
.normalized_documents()
.map(|docs| docs.try_into())
.transpose()?;
// Convert existing chat history
let chat_history: Vec<Message> = completion_request
@ -57,9 +47,19 @@ impl CompletionModel {
.flatten()
.collect();
// Combine all messages into a single history
        // Initialize full history with preamble (or empty if non-existent)
let mut full_history: Vec<Message> = match &completion_request.preamble {
Some(preamble) => vec![Message::system(preamble)],
None => vec![],
};
// Docs appear right after preamble, if they exist
if let Some(docs) = docs {
full_history.extend(docs)
}
// Chat history and prompt appear in the order they were provided
full_history.extend(chat_history);
full_history.extend(prompt);
let mut request = if completion_request.tools.is_empty() {
json!({

View File

@ -15,6 +15,7 @@ use rig::agent::AgentBuilder;
use rig::completion::{CompletionError, CompletionRequest};
use rig::embeddings::{EmbeddingError, EmbeddingsBuilder};
use rig::extractor::ExtractorBuilder;
use rig::message;
use rig::providers::openai::{self, Message};
use rig::OneOrMany;
use rig::{completion, embeddings, Embed};
@ -470,14 +471,19 @@ impl completion::CompletionModel for CompletionModel {
&self,
completion_request: CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
// Add preamble to chat history (if available)
let mut full_history: Vec<Message> = match &completion_request.preamble {
Some(preamble) => vec![Message::system(preamble)],
None => vec![],
};
// Build up the order of messages (context, chat_history)
let mut partial_history = vec![];
if let Some(docs) = completion_request.normalized_documents() {
partial_history.push(docs);
}
partial_history.extend(completion_request.chat_history);
// Initialize full history with preamble (or empty if non-existent)
let mut full_history: Vec<Message> = completion_request
.preamble
.map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
// Convert prompt to user message
let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
tracing::info!("Try to get on-chain system prompt");
let eternal_ai_rpc = std::env::var("ETERNALAI_RPC_URL").unwrap_or_else(|_| "".to_string());
let eternal_ai_contract =
@ -515,19 +521,16 @@ impl completion::CompletionModel for CompletionModel {
}
}
// Convert existing chat history
let chat_history: Vec<Message> = completion_request
.chat_history
.into_iter()
.map(|message| message.try_into())
.collect::<Result<Vec<Vec<Message>>, _>>()?
.into_iter()
.flatten()
.collect();
// Combine all messages into a single history
full_history.extend(chat_history);
full_history.extend(prompt);
// Convert and extend the rest of the history
full_history.extend(
partial_history
.into_iter()
.map(message::Message::try_into)
.collect::<Result<Vec<Vec<Message>>, _>>()?
.into_iter()
.flatten()
.collect::<Vec<_>>(),
);
let request = if completion_request.tools.is_empty() {
json!({