mirror of https://github.com/0xplaygrounds/rig

Merge 1dd15d8f8c into 33e8fc7a65
This commit is contained in: commit 7245a06922
@@ -120,13 +120,13 @@ impl completion::CompletionModel for CompletionModel {
             .model_id(self.model.as_str());

         let tool_config = request.tools_config()?;
-        let prompt_with_history = request.prompt_with_history()?;
+        let messages = request.messages()?;
         converse_builder = converse_builder
             .set_additional_model_request_fields(request.additional_params())
             .set_inference_config(request.inference_config())
             .set_tool_config(tool_config)
             .set_system(request.system_prompt())
-            .set_messages(Some(prompt_with_history));
+            .set_messages(Some(messages));

         let response = converse_builder
             .send()
@@ -28,7 +28,7 @@ impl StreamingCompletionModel for CompletionModel {
             .model_id(self.model.as_str());

         let tool_config = request.tools_config()?;
-        let prompt_with_history = request.prompt_with_history()?;
+        let prompt_with_history = request.messages()?;
         converse_builder = converse_builder
             .set_additional_model_request_fields(request.additional_params())
             .set_inference_config(request.inference_config())
@@ -6,6 +6,8 @@ use aws_sdk_bedrockruntime::types::{
     ToolSpecification,
 };
 use rig::completion::{CompletionError, Message};
+use rig::message::{ContentFormat, DocumentMediaType, UserContent};
+use rig::OneOrMany;

 pub struct AwsCompletionRequest(pub rig::completion::CompletionRequest);
@@ -69,13 +71,30 @@ impl AwsCompletionRequest {
             .map(|system_prompt| vec![SystemContentBlock::Text(system_prompt)])
     }

-    pub fn prompt_with_history(&self) -> Result<Vec<aws_bedrock::Message>, CompletionError> {
-        let mut chat_history = self.0.chat_history.to_owned();
-        let prompt_with_context = self.0.prompt_with_context();
-
+    pub fn messages(&self) -> Result<Vec<aws_bedrock::Message>, CompletionError> {
         let mut full_history: Vec<Message> = Vec::new();
-        full_history.append(&mut chat_history);
-        full_history.push(prompt_with_context);
+
+        if !self.0.documents.is_empty() {
+            let messages = self
+                .0
+                .documents
+                .iter()
+                .map(|doc| doc.to_string())
+                .collect::<Vec<_>>()
+                .join(" | ");
+
+            let content = OneOrMany::one(UserContent::document(
+                messages,
+                Some(ContentFormat::String),
+                Some(DocumentMediaType::TXT),
+            ));
+
+            full_history.push(Message::User { content });
+        }
+
+        self.0.chat_history.iter().for_each(|message| {
+            full_history.push(message.clone());
+        });

         full_history
             .into_iter()
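The new `messages()` above folds every attached document into one plain-text user message that is pushed ahead of the chat history. A minimal sketch of that folding step, using plain strings in place of rig's `Document` type (the " | " separator is taken from the diff above):

    // Sketch: how `messages()` collapses request documents into one text block.
    fn fold_documents(documents: &[String]) -> Option<String> {
        if documents.is_empty() {
            return None; // no documents -> no synthetic user message
        }
        // All documents are joined with " | " and sent as a single
        // plain-text user message ahead of the chat history.
        Some(documents.join(" | "))
    }

    fn main() {
        let docs = vec!["doc A".to_string(), "doc B".to_string()];
        assert_eq!(fold_documents(&docs).as_deref(), Some("doc A | doc B"));
    }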
@@ -11,7 +11,7 @@ async fn main() -> Result<(), anyhow::Error> {

     // Create agent with a single context prompt
     let comedian_agent = client
-        .agent("cognitivecomputations/dolphin3.0-mistral-24b:free")
+        .agent("google/gemini-2.5-pro-exp-03-25:free")
         .preamble("You are a comedian here to entertain the user using humour and jokes.")
         .build();
@@ -3,7 +3,7 @@ use std::env;
 use anyhow::Result;
 use rig::{
     agent::Agent,
-    completion::Chat,
+    completion::Prompt,
     message::Message,
     providers::{cohere, openai},
 };

@@ -49,20 +49,20 @@ impl Debater {

         let resp_a = self
             .gpt_4
-            .chat(prompt_a.as_str(), history_a.clone())
+            .prompt(prompt_a.as_str())
+            .with_history(&mut history_a)
             .await?;
         println!("GPT-4:\n{}", resp_a);
-        history_a.push(Message::user(prompt_a));
-        history_a.push(Message::assistant(resp_a.clone()));
         println!("================================================================");

-        let resp_b = self.coral.chat(resp_a.as_str(), history_b.clone()).await?;
+        let resp_b = self
+            .coral
+            .prompt(resp_a.as_str())
+            .with_history(&mut history_b)
+            .await?;
         println!("Coral:\n{}", resp_b);
         println!("================================================================");

-        history_b.push(Message::user(resp_a));
-        history_b.push(Message::assistant(resp_b.clone()));
-
         last_resp_b = Some(resp_b)
     }
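The migration pattern in this example: the old `Chat::chat(prompt, history.clone())` call plus manual `history.push(...)` bookkeeping becomes a single fluent request, because `with_history(&mut history)` lets the request itself append the user prompt and the assistant reply (see `prompt_request.rs` later in this commit). A sketch of the migrated call shape, with `one_turn` as a hypothetical helper name:

    use rig::{
        agent::Agent,
        completion::{CompletionModel, Prompt, PromptError},
        message::Message,
    };

    // One debate turn: send `prompt`, and let the request record both sides
    // of the exchange in `history` in place.
    async fn one_turn<M: CompletionModel>(
        agent: &Agent<M>,
        prompt: &str,
        history: &mut Vec<Message>,
    ) -> Result<String, PromptError> {
        agent.prompt(prompt).with_history(history).await
    }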
@@ -35,6 +35,7 @@ impl<M: CompletionModel> EnglishTranslator<M> {
 }

 impl<M: CompletionModel> Chat for EnglishTranslator<M> {
+    #[allow(refining_impl_trait)]
     async fn chat(
         &self,
         prompt: impl Into<Message> + Send,
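The `#[allow(refining_impl_trait)]` added here is needed because implementations like this one now spell out a more specific return type than the trait's `impl Future` signature, which trips the `refining_impl_trait` lint (see the note and the rust-lang/rust#121718 link in the new agent completion module later in this commit). A self-contained illustration of the lint with a hypothetical trait, unrelated to rig's actual traits:

    // The trait promises only `impl Display`; the impl "refines" that to a
    // concrete `String`, which Rust flags unless explicitly allowed.
    trait Describe {
        fn describe(&self) -> impl std::fmt::Display;
    }

    struct Meters(f64);

    impl Describe for Meters {
        #[allow(refining_impl_trait)]
        fn describe(&self) -> String {
            format!("{} m", self.0)
        }
    }

    fn main() {
        println!("{}", Meters(3.5).describe());
    }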
@@ -1,99 +1,23 @@
 use rig::{
-    agent::Agent,
-    completion::{self, Completion, PromptError, ToolDefinition},
-    message::{AssistantContent, Message, ToolCall, ToolFunction, ToolResultContent, UserContent},
+    completion::{Prompt, ToolDefinition},
     providers::anthropic,
     tool::Tool,
-    OneOrMany,
 };
 use serde::{Deserialize, Serialize};
 use serde_json::json;

-struct MultiTurnAgent<M: rig::completion::CompletionModel> {
-    agent: Agent<M>,
-    chat_history: Vec<completion::Message>,
-}
-
-impl<M: rig::completion::CompletionModel> MultiTurnAgent<M> {
-    async fn multi_turn_prompt(
-        &mut self,
-        prompt: impl Into<Message> + Send,
-    ) -> Result<String, PromptError> {
-        let mut current_prompt: Message = prompt.into();
-        loop {
-            println!("Current Prompt: {:?}\n", current_prompt);
-            let resp = self
-                .agent
-                .completion(current_prompt.clone(), self.chat_history.clone())
-                .await?
-                .send()
-                .await?;
-
-            let mut final_text = None;
-
-            for content in resp.choice.into_iter() {
-                match content {
-                    AssistantContent::Text(text) => {
-                        println!("Intermediate Response: {:?}\n", text.text);
-                        final_text = Some(text.text.clone());
-                        self.chat_history.push(current_prompt.clone());
-                        let response_message = Message::Assistant {
-                            content: OneOrMany::one(AssistantContent::text(&text.text)),
-                        };
-                        self.chat_history.push(response_message);
-                    }
-                    AssistantContent::ToolCall(content) => {
-                        self.chat_history.push(current_prompt.clone());
-                        let tool_call_msg = AssistantContent::ToolCall(content.clone());
-                        println!("Tool Call Msg: {:?}\n", tool_call_msg);
-
-                        self.chat_history.push(Message::Assistant {
-                            content: OneOrMany::one(tool_call_msg),
-                        });
-
-                        let ToolCall {
-                            id,
-                            function: ToolFunction { name, arguments },
-                        } = content;
-
-                        let tool_result =
-                            self.agent.tools.call(&name, arguments.to_string()).await?;
-
-                        current_prompt = Message::User {
-                            content: OneOrMany::one(UserContent::tool_result(
-                                id,
-                                OneOrMany::one(ToolResultContent::text(tool_result)),
-                            )),
-                        };
-
-                        final_text = None;
-                        break;
-                    }
-                }
-            }
-
-            if let Some(text) = final_text {
-                return Ok(text);
-            }
-        }
-    }
-}
-
 #[tokio::main]
 async fn main() -> anyhow::Result<()> {
-    // tracing_subscriber::registry()
-    //     .with(
-    //         tracing_subscriber::EnvFilter::try_from_default_env()
-    //             .unwrap_or_else(|_| "stdout=info".into()),
-    //     )
-    //     .with(tracing_subscriber::fmt::layer())
-    //     .init();
+    tracing_subscriber::fmt()
+        .with_max_level(tracing::Level::DEBUG)
+        .with_target(false)
+        .init();

     // Create OpenAI client
     let openai_client = anthropic::Client::from_env();

     // Create RAG agent with a single context prompt and a dynamic tool source
-    let calculator_rag = openai_client
+    let agent = openai_client
         .agent(anthropic::CLAUDE_3_5_SONNET)
         .preamble(
             "You are an assistant here to help the user select which tool is most appropriate to perform arithmetic operations.
@@ -109,21 +33,18 @@ async fn main() -> anyhow::Result<()> {
         .tool(Divide)
         .build();

-    let mut agent = MultiTurnAgent {
-        agent: calculator_rag,
-        chat_history: Vec::new(),
-    };
-
     // Prompt the agent and print the response
     let result = agent
-        .multi_turn_prompt("Calculate 5 - 2 = ?. Describe the result to me.")
+        .prompt("Calculate 5 - 2 = ?. Describe the result to me.")
+        .multi_turn(20)
         .await?;

     println!("\n\nOpenAI Calculator Agent: {}", result);

     // Prompt the agent again and print the response
     let result = agent
-        .multi_turn_prompt("Calculate (3 + 5) / 9 = ?. Describe the result to me.")
+        .prompt("Calculate (3 + 5) / 9 = ?. Describe the result to me.")
+        .multi_turn(20)
         .await?;

     println!("\n\nOpenAI Calculator Agent: {}", result);
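With the builder API, a runaway tool-call loop no longer spins forever: once the configured `multi_turn` depth is exhausted, the request resolves to `PromptError::MaxDepthError` (constructed in `prompt_request.rs` later in this commit). A hedged sketch of handling it, with `capped_prompt` as a hypothetical helper name:

    use rig::{
        agent::Agent,
        completion::{CompletionModel, Prompt, PromptError},
    };

    // Ask a question with a small depth cap and report when the agent was
    // still issuing tool calls at the cap.
    async fn capped_prompt<M: CompletionModel>(
        agent: &Agent<M>,
        question: &str,
    ) -> Result<String, PromptError> {
        match agent.prompt(question).multi_turn(2).await {
            Err(PromptError::MaxDepthError { max_depth, .. }) => {
                Ok(format!("gave up after {max_depth} tool-call turns"))
            }
            other => other,
        }
    }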
@@ -0,0 +1,261 @@
use rig::{
    agent::Agent,
    completion::{CompletionError, CompletionModel, Prompt, PromptError, ToolDefinition},
    extractor::Extractor,
    message::Message,
    providers::anthropic,
    tool::Tool,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;

const CHAIN_OF_THOUGHT_PROMPT: &str = "
You are an assistant that extracts reasoning steps from a given prompt.
Do not return text, only return a tool call.
";

#[derive(Deserialize, Serialize, Debug, Clone, JsonSchema)]
struct ChainOfThoughtSteps {
    steps: Vec<String>,
}

struct ReasoningAgent<M: CompletionModel> {
    chain_of_thought_extractor: Extractor<M, ChainOfThoughtSteps>,
    executor: Agent<M>,
}

impl<M: CompletionModel> Prompt for ReasoningAgent<M> {
    #[allow(refining_impl_trait)]
    async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
        let prompt: Message = prompt.into();
        let mut chat_history = vec![prompt.clone()];
        let extracted = self
            .chain_of_thought_extractor
            .extract(prompt)
            .await
            .map_err(|e| {
                tracing::error!("Extraction error: {:?}", e);
                CompletionError::ProviderError("".into())
            })?;

        if extracted.steps.is_empty() {
            return Ok("No reasoning steps provided.".into());
        }

        let mut reasoning_prompt = String::new();
        for (i, step) in extracted.steps.iter().enumerate() {
            reasoning_prompt.push_str(&format!("Step {}: {}\n", i + 1, step));
        }

        let response = self
            .executor
            .prompt(reasoning_prompt.as_str())
            .with_history(&mut chat_history)
            .multi_turn(20)
            .await?;

        tracing::info!(
            "full chat history generated: {}",
            serde_json::to_string_pretty(&chat_history).unwrap()
        );

        Ok(response)
    }
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    tracing_subscriber::fmt()
        .with_max_level(tracing::Level::DEBUG)
        .with_target(false)
        .init();

    // Create OpenAI client
    let anthropic_client = anthropic::Client::from_env();

    let agent = ReasoningAgent {
        chain_of_thought_extractor: anthropic_client
            .extractor(anthropic::CLAUDE_3_5_SONNET)
            .preamble(CHAIN_OF_THOUGHT_PROMPT)
            .build(),

        executor: anthropic_client
            .agent(anthropic::CLAUDE_3_5_SONNET)
            .preamble(
                "You are an assistant here to help the user select which tool is most appropriate to perform arithmetic operations.
                Follow these instructions closely.
                1. Consider the user's request carefully and identify the core elements of the request.
                2. Select which tool among those made available to you is appropriate given the context.
                3. This is very important: never perform the operation yourself.
                4. When you think you've finished calling tools for the operation, present the final result from the series of tool calls you made.
                "
            )
            .tool(Add)
            .tool(Subtract)
            .tool(Multiply)
            .tool(Divide)
            .build(),
    };

    // Prompt the agent and print the response
    let result = agent
        .prompt("Calculate ((15 + 25) * (100 - 50)) / (200 / (10 + 10))")
        .await?;

    println!("\n\nReasoning Agent Chat History: {}", result);

    Ok(())
}

#[derive(Deserialize)]
struct OperationArgs {
    x: i32,
    y: i32,
}

#[derive(Debug, thiserror::Error)]
#[error("Math error")]
struct MathError;

#[derive(Deserialize, Serialize)]
struct Add;
impl Tool for Add {
    const NAME: &'static str = "add";

    type Error = MathError;
    type Args = OperationArgs;
    type Output = i32;

    async fn definition(&self, _prompt: String) -> ToolDefinition {
        serde_json::from_value(json!({
            "name": "add",
            "description": "Add x and y together",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The first number to add"
                    },
                    "y": {
                        "type": "number",
                        "description": "The second number to add"
                    }
                }
            }
        }))
        .expect("Tool Definition")
    }

    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        let result = args.x + args.y;
        Ok(result)
    }
}

#[derive(Deserialize, Serialize)]
struct Subtract;
impl Tool for Subtract {
    const NAME: &'static str = "subtract";

    type Error = MathError;
    type Args = OperationArgs;
    type Output = i32;

    async fn definition(&self, _prompt: String) -> ToolDefinition {
        serde_json::from_value(json!({
            "name": "subtract",
            "description": "Subtract y from x (i.e.: x - y)",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The number to subtract from"
                    },
                    "y": {
                        "type": "number",
                        "description": "The number to subtract"
                    }
                }
            }
        }))
        .expect("Tool Definition")
    }

    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        let result = args.x - args.y;
        Ok(result)
    }
}

struct Multiply;
impl Tool for Multiply {
    const NAME: &'static str = "multiply";

    type Error = MathError;
    type Args = OperationArgs;
    type Output = i32;

    async fn definition(&self, _prompt: String) -> ToolDefinition {
        serde_json::from_value(json!({
            "name": "multiply",
            "description": "Compute the product of x and y (i.e.: x * y)",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The first factor in the product"
                    },
                    "y": {
                        "type": "number",
                        "description": "The second factor in the product"
                    }
                }
            }
        }))
        .expect("Tool Definition")
    }

    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        let result = args.x * args.y;
        Ok(result)
    }
}

struct Divide;
impl Tool for Divide {
    const NAME: &'static str = "divide";

    type Error = MathError;
    type Args = OperationArgs;
    type Output = i32;

    async fn definition(&self, _prompt: String) -> ToolDefinition {
        serde_json::from_value(json!({
            "name": "divide",
            "description": "Compute the Quotient of x and y (i.e.: x / y). Useful for ratios.",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The Dividend of the division. The number being divided"
                    },
                    "y": {
                        "type": "number",
                        "description": "The Divisor of the division. The number by which the dividend is being divided"
                    }
                }
            }
        }))
        .expect("Tool Definition")
    }

    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        let result = args.x / args.y;
        Ok(result)
    }
}
@@ -1,520 +0,0 @@
//! This module contains the implementation of the [Agent] struct and its builder.
//!
//! The [Agent] struct represents an LLM agent, which combines an LLM model with a preamble (system prompt),
//! a set of context documents, and a set of tools. Note: both context documents and tools can be either
//! static (i.e.: they are always provided) or dynamic (i.e.: they are RAGged at prompt-time).
//!
//! The [Agent] struct is highly configurable, allowing the user to define anything from
//! a simple bot with a specific system prompt to a complex RAG system with a set of dynamic
//! context documents and tools.
//!
//! The [Agent] struct implements the [Completion] and [Prompt] traits, allowing it to be used for generating
//! completions responses and prompts. The [Agent] struct also implements the [Chat] trait, which allows it to
//! be used for generating chat completions.
//!
//! The [AgentBuilder] implements the builder pattern for creating instances of [Agent].
//! It allows configuring the model, preamble, context documents, tools, temperature, and additional parameters
//! before building the agent.
//!
//! # Example
//! ```rust
//! use rig::{
//!     completion::{Chat, Completion, Prompt},
//!     providers::openai,
//! };
//!
//! let openai = openai::Client::from_env();
//!
//! // Configure the agent
//! let agent = openai.agent("gpt-4o")
//!     .preamble("System prompt")
//!     .context("Context document 1")
//!     .context("Context document 2")
//!     .tool(tool1)
//!     .tool(tool2)
//!     .temperature(0.8)
//!     .additional_params(json!({"foo": "bar"}))
//!     .build();
//!
//! // Use the agent for completions and prompts
//! // Generate a chat completion response from a prompt and chat history
//! let chat_response = agent.chat("Prompt", chat_history)
//!     .await
//!     .expect("Failed to chat with Agent");
//!
//! // Generate a prompt completion response from a simple prompt
//! let chat_response = agent.prompt("Prompt")
//!     .await
//!     .expect("Failed to prompt the Agent");
//!
//! // Generate a completion request builder from a prompt and chat history. The builder
//! // will contain the agent's configuration (i.e.: preamble, context documents, tools,
//! // model parameters, etc.), but these can be overwritten.
//! let completion_req_builder = agent.completion("Prompt", chat_history)
//!     .await
//!     .expect("Failed to create completion request builder");
//!
//! let response = completion_req_builder
//!     .temperature(0.9) // Overwrite the agent's temperature
//!     .send()
//!     .await
//!     .expect("Failed to send completion request");
//! ```
//!
//! RAG Agent example
//! ```rust
//! use rig::{
//!     completion::Prompt,
//!     embeddings::EmbeddingsBuilder,
//!     providers::openai,
//!     vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
//! };
//!
//! // Initialize OpenAI client
//! let openai = openai::Client::from_env();
//!
//! // Initialize OpenAI embedding model
//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002);
//!
//! // Create vector store, compute embeddings and load them in the store
//! let mut vector_store = InMemoryVectorStore::default();
//!
//! let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
//!     .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
//!     .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
//!     .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
//!     .build()
//!     .await
//!     .expect("Failed to build embeddings");
//!
//! vector_store.add_documents(embeddings)
//!     .await
//!     .expect("Failed to add documents");
//!
//! // Create vector store index
//! let index = vector_store.index(embedding_model);
//!
//! let agent = openai.agent(openai::GPT_4O)
//!     .preamble("
//!         You are a dictionary assistant here to assist the user in understanding the meaning of words.
//!         You will find additional non-standard word definitions that could be useful below.
//!     ")
//!     .dynamic_context(1, index)
//!     .build();
//!
//! // Prompt the agent and print the response
//! let response = agent.prompt("What does \"glarb-glarb\" mean?").await
//!     .expect("Failed to prompt the agent");
//! ```
use std::collections::HashMap;

use futures::{stream, StreamExt, TryStreamExt};

use crate::{
    completion::{
        Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
        Message, Prompt, PromptError,
    },
    message::AssistantContent,
    streaming::{
        StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
        StreamingResult,
    },
    tool::{Tool, ToolSet},
    vector_store::{VectorStoreError, VectorStoreIndexDyn},
};

#[cfg(feature = "mcp")]
use crate::tool::McpTool;

/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
/// All context documents and tools are always provided to the agent when prompted.
///
/// # Example
/// ```
/// use rig::{completion::Prompt, providers::openai};
///
/// let openai = openai::Client::from_env();
///
/// let comedian_agent = openai
///     .agent("gpt-4o")
///     .preamble("You are a comedian here to entertain the user using humour and jokes.")
///     .temperature(0.9)
///     .build();
///
/// let response = comedian_agent.prompt("Entertain me!")
///     .await
///     .expect("Failed to prompt the agent");
/// ```
pub struct Agent<M: CompletionModel> {
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    model: M,
    /// System prompt
    preamble: String,
    /// Context documents always available to the agent
    static_context: Vec<Document>,
    /// Tools that are always available to the agent (identified by their name)
    static_tools: Vec<String>,
    /// Temperature of the model
    temperature: Option<f64>,
    /// Maximum number of tokens for the completion
    max_tokens: Option<u64>,
    /// Additional parameters to be passed to the model
    additional_params: Option<serde_json::Value>,
    /// List of vector store, with the sample number
    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Dynamic tools
    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Actual tool implementations
    pub tools: ToolSet,
}

impl<M: CompletionModel> Completion<M> for Agent<M> {
    async fn completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        let prompt = prompt.into();
        let rag_text = prompt.rag_text().clone();

        let completion_request = self
            .model
            .completion_request(prompt)
            .preamble(self.preamble.clone())
            .messages(chat_history)
            .temperature_opt(self.temperature)
            .max_tokens_opt(self.max_tokens)
            .additional_params_opt(self.additional_params.clone())
            .documents(self.static_context.clone());

        let agent = match &rag_text {
            Some(text) => {
                let dynamic_context = stream::iter(self.dynamic_context.iter())
                    .then(|(num_sample, index)| async {
                        Ok::<_, VectorStoreError>(
                            index
                                .top_n(text, *num_sample)
                                .await?
                                .into_iter()
                                .map(|(_, id, doc)| {
                                    // Pretty print the document if possible for better readability
                                    let text = serde_json::to_string_pretty(&doc)
                                        .unwrap_or_else(|_| doc.to_string());

                                    Document {
                                        id,
                                        text,
                                        additional_props: HashMap::new(),
                                    }
                                })
                                .collect::<Vec<_>>(),
                        )
                    })
                    .try_fold(vec![], |mut acc, docs| async {
                        acc.extend(docs);
                        Ok(acc)
                    })
                    .await
                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

                let dynamic_tools = stream::iter(self.dynamic_tools.iter())
                    .then(|(num_sample, index)| async {
                        Ok::<_, VectorStoreError>(
                            index
                                .top_n_ids(text, *num_sample)
                                .await?
                                .into_iter()
                                .map(|(_, id)| id)
                                .collect::<Vec<_>>(),
                        )
                    })
                    .try_fold(vec![], |mut acc, docs| async {
                        for doc in docs {
                            if let Some(tool) = self.tools.get(&doc) {
                                acc.push(tool.definition(text.into()).await)
                            } else {
                                tracing::warn!("Tool implementation not found in toolset: {}", doc);
                            }
                        }
                        Ok(acc)
                    })
                    .await
                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

                let static_tools = stream::iter(self.static_tools.iter())
                    .filter_map(|toolname| async move {
                        if let Some(tool) = self.tools.get(toolname) {
                            Some(tool.definition(text.into()).await)
                        } else {
                            tracing::warn!(
                                "Tool implementation not found in toolset: {}",
                                toolname
                            );
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .await;

                completion_request
                    .documents(dynamic_context)
                    .tools([static_tools.clone(), dynamic_tools].concat())
            }
            None => {
                let static_tools = stream::iter(self.static_tools.iter())
                    .filter_map(|toolname| async move {
                        if let Some(tool) = self.tools.get(toolname) {
                            // TODO: tool definitions should likely take an `Option<String>`
                            Some(tool.definition("".into()).await)
                        } else {
                            tracing::warn!(
                                "Tool implementation not found in toolset: {}",
                                toolname
                            );
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .await;

                completion_request.tools(static_tools)
            }
        };

        Ok(agent)
    }
}

impl<M: CompletionModel> Prompt for Agent<M> {
    async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
        self.chat(prompt, vec![]).await
    }
}

impl<M: CompletionModel> Prompt for &Agent<M> {
    async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
        self.chat(prompt, vec![]).await
    }
}

impl<M: CompletionModel> Chat for Agent<M> {
    async fn chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> Result<String, PromptError> {
        let resp = self.completion(prompt, chat_history).await?.send().await?;

        // TODO: consider returning a `Message` instead of `String` for parallel responses / tool calls
        match resp.choice.first() {
            AssistantContent::Text(text) => Ok(text.text.clone()),
            AssistantContent::ToolCall(tool_call) => Ok(self
                .tools
                .call(
                    &tool_call.function.name,
                    tool_call.function.arguments.to_string(),
                )
                .await?),
        }
    }
}

/// A builder for creating an agent
///
/// # Example
/// ```
/// use rig::{providers::openai, agent::AgentBuilder};
///
/// let openai = openai::Client::from_env();
///
/// let gpt4o = openai.completion_model("gpt-4o");
///
/// // Configure the agent
/// let agent = AgentBuilder::new(model)
///     .preamble("System prompt")
///     .context("Context document 1")
///     .context("Context document 2")
///     .tool(tool1)
///     .tool(tool2)
///     .temperature(0.8)
///     .additional_params(json!({"foo": "bar"}))
///     .build();
/// ```
pub struct AgentBuilder<M: CompletionModel> {
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    model: M,
    /// System prompt
    preamble: Option<String>,
    /// Context documents always available to the agent
    static_context: Vec<Document>,
    /// Tools that are always available to the agent (by name)
    static_tools: Vec<String>,
    /// Additional parameters to be passed to the model
    additional_params: Option<serde_json::Value>,
    /// Maximum number of tokens for the completion
    max_tokens: Option<u64>,
    /// List of vector store, with the sample number
    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Dynamic tools
    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Temperature of the model
    temperature: Option<f64>,
    /// Actual tool implementations
    tools: ToolSet,
}

impl<M: CompletionModel> AgentBuilder<M> {
    pub fn new(model: M) -> Self {
        Self {
            model,
            preamble: None,
            static_context: vec![],
            static_tools: vec![],
            temperature: None,
            max_tokens: None,
            additional_params: None,
            dynamic_context: vec![],
            dynamic_tools: vec![],
            tools: ToolSet::default(),
        }
    }

    /// Set the system prompt
    pub fn preamble(mut self, preamble: &str) -> Self {
        self.preamble = Some(preamble.into());
        self
    }

    /// Append to the preamble of the agent
    pub fn append_preamble(mut self, doc: &str) -> Self {
        self.preamble = Some(format!(
            "{}\n{}",
            self.preamble.unwrap_or_else(|| "".into()),
            doc
        ));
        self
    }

    /// Add a static context document to the agent
    pub fn context(mut self, doc: &str) -> Self {
        self.static_context.push(Document {
            id: format!("static_doc_{}", self.static_context.len()),
            text: doc.into(),
            additional_props: HashMap::new(),
        });
        self
    }

    /// Add a static tool to the agent
    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
        let toolname = tool.name();
        self.tools.add_tool(tool);
        self.static_tools.push(toolname);
        self
    }

    // Add an MCP tool to the agent
    #[cfg(feature = "mcp")]
    pub fn mcp_tool<T: mcp_core::transport::Transport>(
        mut self,
        tool: mcp_core::types::Tool,
        client: mcp_core::client::Client<T>,
    ) -> Self {
        let toolname = tool.name.clone();
        self.tools.add_tool(McpTool::from_mcp_server(tool, client));
        self.static_tools.push(toolname);
        self
    }

    /// Add some dynamic context to the agent. On each prompt, `sample` documents from the
    /// dynamic context will be inserted in the request.
    pub fn dynamic_context(
        mut self,
        sample: usize,
        dynamic_context: impl VectorStoreIndexDyn + 'static,
    ) -> Self {
        self.dynamic_context
            .push((sample, Box::new(dynamic_context)));
        self
    }

    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
    /// dynamic toolset will be inserted in the request.
    pub fn dynamic_tools(
        mut self,
        sample: usize,
        dynamic_tools: impl VectorStoreIndexDyn + 'static,
        toolset: ToolSet,
    ) -> Self {
        self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
        self.tools.add_tools(toolset);
        self
    }

    /// Set the temperature of the model
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Set the maximum number of tokens for the completion
    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Set additional parameters to be passed to the model
    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
        self.additional_params = Some(params);
        self
    }

    /// Build the agent
    pub fn build(self) -> Agent<M> {
        Agent {
            model: self.model,
            preamble: self.preamble.unwrap_or_default(),
            static_context: self.static_context,
            static_tools: self.static_tools,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
            dynamic_context: self.dynamic_context,
            dynamic_tools: self.dynamic_tools,
            tools: self.tools,
        }
    }
}

impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
    async fn stream_completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        // Reuse the existing completion implementation to build the request
        // This ensures streaming and non-streaming use the same request building logic
        self.completion(prompt, chat_history).await
    }
}

impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
    async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
        self.stream_chat(prompt, vec![]).await
    }
}

impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
    async fn stream_chat(
        &self,
        prompt: &str,
        chat_history: Vec<Message>,
    ) -> Result<StreamingResult, CompletionError> {
        self.stream_completion(prompt, chat_history)
            .await?
            .stream()
            .await
    }
}
@@ -0,0 +1,179 @@
use std::collections::HashMap;

use crate::{
    completion::{CompletionModel, Document},
    tool::{Tool, ToolSet},
    vector_store::VectorStoreIndexDyn,
};

#[cfg(feature = "mcp")]
use crate::tool::McpTool;

use super::Agent;

/// A builder for creating an agent
///
/// # Example
/// ```
/// use rig::{providers::openai, agent::AgentBuilder};
///
/// let openai = openai::Client::from_env();
///
/// let gpt4o = openai.completion_model("gpt-4o");
///
/// // Configure the agent
/// let agent = AgentBuilder::new(model)
///     .preamble("System prompt")
///     .context("Context document 1")
///     .context("Context document 2")
///     .tool(tool1)
///     .tool(tool2)
///     .temperature(0.8)
///     .additional_params(json!({"foo": "bar"}))
///     .build();
/// ```
pub struct AgentBuilder<M: CompletionModel> {
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    model: M,
    /// System prompt
    preamble: Option<String>,
    /// Context documents always available to the agent
    static_context: Vec<Document>,
    /// Tools that are always available to the agent (by name)
    static_tools: Vec<String>,
    /// Additional parameters to be passed to the model
    additional_params: Option<serde_json::Value>,
    /// Maximum number of tokens for the completion
    max_tokens: Option<u64>,
    /// List of vector store, with the sample number
    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Dynamic tools
    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Temperature of the model
    temperature: Option<f64>,
    /// Actual tool implementations
    tools: ToolSet,
}

impl<M: CompletionModel> AgentBuilder<M> {
    pub fn new(model: M) -> Self {
        Self {
            model,
            preamble: None,
            static_context: vec![],
            static_tools: vec![],
            temperature: None,
            max_tokens: None,
            additional_params: None,
            dynamic_context: vec![],
            dynamic_tools: vec![],
            tools: ToolSet::default(),
        }
    }

    /// Set the system prompt
    pub fn preamble(mut self, preamble: &str) -> Self {
        self.preamble = Some(preamble.into());
        self
    }

    /// Append to the preamble of the agent
    pub fn append_preamble(mut self, doc: &str) -> Self {
        self.preamble = Some(format!(
            "{}\n{}",
            self.preamble.unwrap_or_else(|| "".into()),
            doc
        ));
        self
    }

    /// Add a static context document to the agent
    pub fn context(mut self, doc: &str) -> Self {
        self.static_context.push(Document {
            id: format!("static_doc_{}", self.static_context.len()),
            text: doc.into(),
            additional_props: HashMap::new(),
        });
        self
    }

    /// Add a static tool to the agent
    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
        let toolname = tool.name();
        self.tools.add_tool(tool);
        self.static_tools.push(toolname);
        self
    }

    // Add an MCP tool to the agent
    #[cfg(feature = "mcp")]
    pub fn mcp_tool<T: mcp_core::transport::Transport>(
        mut self,
        tool: mcp_core::types::Tool,
        client: mcp_core::client::Client<T>,
    ) -> Self {
        let toolname = tool.name.clone();
        self.tools.add_tool(McpTool::from_mcp_server(tool, client));
        self.static_tools.push(toolname);
        self
    }

    /// Add some dynamic context to the agent. On each prompt, `sample` documents from the
    /// dynamic context will be inserted in the request.
    pub fn dynamic_context(
        mut self,
        sample: usize,
        dynamic_context: impl VectorStoreIndexDyn + 'static,
    ) -> Self {
        self.dynamic_context
            .push((sample, Box::new(dynamic_context)));
        self
    }

    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
    /// dynamic toolset will be inserted in the request.
    pub fn dynamic_tools(
        mut self,
        sample: usize,
        dynamic_tools: impl VectorStoreIndexDyn + 'static,
        toolset: ToolSet,
    ) -> Self {
        self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
        self.tools.add_tools(toolset);
        self
    }

    /// Set the temperature of the model
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Set the maximum number of tokens for the completion
    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Set additional parameters to be passed to the model
    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
        self.additional_params = Some(params);
        self
    }

    /// Build the agent
    pub fn build(self) -> Agent<M> {
        Agent {
            model: self.model,
            preamble: self.preamble.unwrap_or_default(),
            static_context: self.static_context,
            static_tools: self.static_tools,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
            dynamic_context: self.dynamic_context,
            dynamic_tools: self.dynamic_tools,
            tools: self.tools,
        }
    }
}
@@ -0,0 +1,244 @@
use std::collections::HashMap;

use futures::{stream, StreamExt, TryStreamExt};

use crate::{
    completion::{
        Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
        Message, Prompt, PromptError,
    },
    streaming::{
        StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
        StreamingResult,
    },
    tool::ToolSet,
    vector_store::VectorStoreError,
};

use super::prompt_request::PromptRequest;

/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
/// All context documents and tools are always provided to the agent when prompted.
///
/// # Example
/// ```
/// use rig::{completion::Prompt, providers::openai};
///
/// let openai = openai::Client::from_env();
///
/// let comedian_agent = openai
///     .agent("gpt-4o")
///     .preamble("You are a comedian here to entertain the user using humour and jokes.")
///     .temperature(0.9)
///     .build();
///
/// let response = comedian_agent.prompt("Entertain me!")
///     .await
///     .expect("Failed to prompt the agent");
/// ```
pub struct Agent<M: CompletionModel> {
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    pub model: M,
    /// System prompt
    pub preamble: String,
    /// Context documents always available to the agent
    pub static_context: Vec<Document>,
    /// Tools that are always available to the agent (identified by their name)
    pub static_tools: Vec<String>,
    /// Temperature of the model
    pub temperature: Option<f64>,
    /// Maximum number of tokens for the completion
    pub max_tokens: Option<u64>,
    /// Additional parameters to be passed to the model
    pub additional_params: Option<serde_json::Value>,
    /// List of vector store, with the sample number
    pub dynamic_context: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
    /// Dynamic tools
    pub dynamic_tools: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
    /// Actual tool implementations
    pub tools: ToolSet,
}

impl<M: CompletionModel> Completion<M> for Agent<M> {
    async fn completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        let prompt = prompt.into();
        let rag_text = prompt.rag_text().clone();

        let completion_request = self
            .model
            .completion_request(prompt)
            .preamble(self.preamble.clone())
            .messages(chat_history)
            .temperature_opt(self.temperature)
            .max_tokens_opt(self.max_tokens)
            .additional_params_opt(self.additional_params.clone())
            .documents(self.static_context.clone());

        let agent = match &rag_text {
            Some(text) => {
                let dynamic_context = stream::iter(self.dynamic_context.iter())
                    .then(|(num_sample, index)| async {
                        Ok::<_, VectorStoreError>(
                            index
                                .top_n(text, *num_sample)
                                .await?
                                .into_iter()
                                .map(|(_, id, doc)| {
                                    // Pretty print the document if possible for better readability
                                    let text = serde_json::to_string_pretty(&doc)
                                        .unwrap_or_else(|_| doc.to_string());

                                    Document {
                                        id,
                                        text,
                                        additional_props: HashMap::new(),
                                    }
                                })
                                .collect::<Vec<_>>(),
                        )
                    })
                    .try_fold(vec![], |mut acc, docs| async {
                        acc.extend(docs);
                        Ok(acc)
                    })
                    .await
                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

                let dynamic_tools = stream::iter(self.dynamic_tools.iter())
                    .then(|(num_sample, index)| async {
                        Ok::<_, VectorStoreError>(
                            index
                                .top_n_ids(text, *num_sample)
                                .await?
                                .into_iter()
                                .map(|(_, id)| id)
                                .collect::<Vec<_>>(),
                        )
                    })
                    .try_fold(vec![], |mut acc, docs| async {
                        for doc in docs {
                            if let Some(tool) = self.tools.get(&doc) {
                                acc.push(tool.definition(text.into()).await)
                            } else {
                                tracing::warn!("Tool implementation not found in toolset: {}", doc);
                            }
                        }
                        Ok(acc)
                    })
                    .await
                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

                let static_tools = stream::iter(self.static_tools.iter())
                    .filter_map(|toolname| async move {
                        if let Some(tool) = self.tools.get(toolname) {
                            Some(tool.definition(text.into()).await)
                        } else {
                            tracing::warn!(
                                "Tool implementation not found in toolset: {}",
                                toolname
                            );
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .await;

                completion_request
                    .documents(dynamic_context)
                    .tools([static_tools.clone(), dynamic_tools].concat())
            }
            None => {
                let static_tools = stream::iter(self.static_tools.iter())
                    .filter_map(|toolname| async move {
                        if let Some(tool) = self.tools.get(toolname) {
                            // TODO: tool definitions should likely take an `Option<String>`
                            Some(tool.definition("".into()).await)
                        } else {
                            tracing::warn!(
                                "Tool implementation not found in toolset: {}",
                                toolname
                            );
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .await;

                completion_request.tools(static_tools)
            }
        };

        Ok(agent)
    }
}

// Here, we need to ensure that usage of `.prompt` on agent uses these redefinitions on the opaque
// `Prompt` trait so that when `.prompt` is used at the call-site, it'll use the more specific
// `PromptRequest` implementation for `Agent`, making the builder's usage fluent.
//
// References:
// - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait)

#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for Agent<M> {
    fn prompt(&self, prompt: impl Into<Message> + Send) -> PromptRequest<M> {
        PromptRequest::new(self, prompt)
    }
}

#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for &Agent<M> {
    fn prompt(&self, prompt: impl Into<Message> + Send) -> PromptRequest<M> {
        PromptRequest::new(*self, prompt)
    }
}

#[allow(refining_impl_trait)]
impl<M: CompletionModel> Chat for Agent<M> {
    async fn chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> Result<String, PromptError> {
        let mut cloned_history = chat_history.clone();
        PromptRequest::new(self, prompt)
            .with_history(&mut cloned_history)
            .await
    }
}

impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
    async fn stream_completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        // Reuse the existing completion implementation to build the request
        // This ensures streaming and non-streaming use the same request building logic
        self.completion(prompt, chat_history).await
    }
}

impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
    async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
        self.stream_chat(prompt, vec![]).await
    }
}

impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
    async fn stream_chat(
        &self,
        prompt: &str,
        chat_history: Vec<Message>,
    ) -> Result<StreamingResult, CompletionError> {
        self.stream_completion(prompt, chat_history)
            .await?
            .stream()
            .await
    }
}
@@ -0,0 +1,116 @@
//! This module contains the implementation of the [Agent] struct and its builder.
//!
//! The [Agent] struct represents an LLM agent, which combines an LLM model with a preamble (system prompt),
//! a set of context documents, and a set of tools. Note: both context documents and tools can be either
//! static (i.e.: they are always provided) or dynamic (i.e.: they are RAGged at prompt-time).
//!
//! The [Agent] struct is highly configurable, allowing the user to define anything from
//! a simple bot with a specific system prompt to a complex RAG system with a set of dynamic
//! context documents and tools.
//!
//! The [Agent] struct implements the [crate::completion::Completion] and [crate::completion::Prompt] traits,
//! allowing it to be used for generating completions responses and prompts. The [Agent] struct also
//! implements the [crate::completion::Chat] trait, which allows it to be used for generating chat completions.
//!
//! The [AgentBuilder] implements the builder pattern for creating instances of [Agent].
//! It allows configuring the model, preamble, context documents, tools, temperature, and additional parameters
//! before building the agent.
//!
//! # Example
//! ```rust
//! use rig::{
//!     completion::{Chat, Completion, Prompt},
//!     providers::openai,
//! };
//!
//! let openai = openai::Client::from_env();
//!
//! // Configure the agent
//! let agent = openai.agent("gpt-4o")
//!     .preamble("System prompt")
//!     .context("Context document 1")
//!     .context("Context document 2")
//!     .tool(tool1)
//!     .tool(tool2)
//!     .temperature(0.8)
//!     .additional_params(json!({"foo": "bar"}))
//!     .build();
//!
//! // Use the agent for completions and prompts
//! // Generate a chat completion response from a prompt and chat history
//! let chat_response = agent.chat("Prompt", chat_history)
//!     .await
//!     .expect("Failed to chat with Agent");
//!
//! // Generate a prompt completion response from a simple prompt
//! let chat_response = agent.prompt("Prompt")
//!     .await
//!     .expect("Failed to prompt the Agent");
//!
//! // Generate a completion request builder from a prompt and chat history. The builder
//! // will contain the agent's configuration (i.e.: preamble, context documents, tools,
//! // model parameters, etc.), but these can be overwritten.
//! let completion_req_builder = agent.completion("Prompt", chat_history)
//!     .await
//!     .expect("Failed to create completion request builder");
//!
//! let response = completion_req_builder
//!     .temperature(0.9) // Overwrite the agent's temperature
//!     .send()
//!     .await
//!     .expect("Failed to send completion request");
//! ```
//!
//! RAG Agent example
//! ```rust
//! use rig::{
//!     completion::Prompt,
//!     embeddings::EmbeddingsBuilder,
//!     providers::openai,
//!     vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
//! };
//!
//! // Initialize OpenAI client
//! let openai = openai::Client::from_env();
//!
//! // Initialize OpenAI embedding model
//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002);
//!
//! // Create vector store, compute embeddings and load them in the store
//! let mut vector_store = InMemoryVectorStore::default();
//!
//! let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
//!     .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
//!     .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
//!     .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
//!     .build()
//!     .await
//!     .expect("Failed to build embeddings");
//!
//! vector_store.add_documents(embeddings)
//!     .await
//!     .expect("Failed to add documents");
//!
//! // Create vector store index
//! let index = vector_store.index(embedding_model);
//!
//! let agent = openai.agent(openai::GPT_4O)
//!     .preamble("
//!         You are a dictionary assistant here to assist the user in understanding the meaning of words.
//!         You will find additional non-standard word definitions that could be useful below.
//!     ")
//!     .dynamic_context(1, index)
//!     .build();
//!
//! // Prompt the agent and print the response
//! let response = agent.prompt("What does \"glarb-glarb\" mean?").await
//!     .expect("Failed to prompt the agent");
//! ```

mod builder;
mod completion;
mod prompt_request;

pub use builder::AgentBuilder;
pub use completion::Agent;
pub use prompt_request::PromptRequest;
@ -0,0 +1,173 @@
|
|||
use std::future::IntoFuture;
|
||||
|
||||
use futures::{future::BoxFuture, stream, FutureExt, StreamExt};
|
||||
|
||||
use crate::{
|
||||
completion::{Completion, CompletionError, CompletionModel, Message, PromptError},
|
||||
message::{AssistantContent, UserContent},
|
||||
tool::ToolSetError,
|
||||
OneOrMany,
|
||||
};
|
||||
|
||||
use super::Agent;
|
||||
|
||||
/// A builder for creating prompt requests with customizable options.
|
||||
/// Uses generics to track which options have been set during the build process.
|
||||
pub struct PromptRequest<'a, M: CompletionModel> {
|
||||
/// The prompt message to send to the model
|
||||
prompt: Message,
|
||||
/// Optional chat history to include with the prompt
|
||||
/// Note: chat history needs to outlive the agent as it might be used with other agents
|
||||
chat_history: Option<&'a mut Vec<Message>>,
|
||||
/// Maximum depth for multi-turn conversations (0 means no multi-turn)
|
||||
max_depth: usize,
|
||||
/// The agent to use for execution
|
||||
agent: &'a Agent<M>,
|
||||
}
|
||||
|
||||
impl<'a, M: CompletionModel> PromptRequest<'a, M> {
|
||||
/// Create a new PromptRequest with the given prompt and model
|
||||
pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
|
||||
Self {
|
||||
prompt: prompt.into(),
|
||||
chat_history: None,
|
||||
max_depth: 0,
|
||||
agent,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, M: CompletionModel> PromptRequest<'a, M> {
|
||||
/// Set the maximum depth for multi-turn conversations
|
||||
pub fn multi_turn(self, depth: usize) -> PromptRequest<'a, M> {
|
||||
PromptRequest {
|
||||
prompt: self.prompt,
|
||||
chat_history: self.chat_history,
|
||||
max_depth: depth,
|
||||
agent: self.agent,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add chat history to the prompt request
|
||||
pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, M> {
|
||||
PromptRequest {
|
||||
prompt: self.prompt,
|
||||
chat_history: Some(history),
|
||||
max_depth: self.max_depth,
|
||||
agent: self.agent,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Due to [RFC 2515](https://github.com/rust-lang/rust/issues/63063), we have to use a `BoxFuture`
/// for the `IntoFuture` implementation. In the future, we should be able to use `impl Future<...>`
/// directly via the associated type.
impl<'a, M: CompletionModel> IntoFuture for PromptRequest<'a, M> {
    type Output = Result<String, PromptError>;
    type IntoFuture = BoxFuture<'a, Self::Output>; // This future should not outlive the agent

    fn into_future(self) -> Self::IntoFuture {
        self.send().boxed()
    }
}
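// A minimal usage sketch, assuming an `agent` built via `AgentBuilder` as in
// the module docs (the prompt text is illustrative). Because `PromptRequest`
// implements `IntoFuture`, the builder can be awaited directly:
//
//     let mut history: Vec<Message> = Vec::new();
//     let answer = agent
//         .prompt("What does \"glarb-glarb\" mean?")
//         .with_history(&mut history)
//         .multi_turn(3) // allow up to three tool-call round trips
//         .await?;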
impl<M: CompletionModel> PromptRequest<'_, M> {
    async fn send(self) -> Result<String, PromptError> {
        let agent = self.agent;
        let mut prompt = self.prompt;
        let chat_history = if let Some(history) = self.chat_history {
            history
        } else {
            &mut Vec::new()
        };

        let mut current_max_depth = 0;
        // We need at least 2 loops for 1 round trip (the user expects a normal message)
        while current_max_depth <= self.max_depth + 1 {
            current_max_depth += 1;

            if self.max_depth > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_depth,
                    self.max_depth
                );
            }

            let resp = agent
                .completion(prompt.clone(), chat_history.to_vec())
                .await?
                .send()
                .await?;

            chat_history.push(prompt);

            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            chat_history.push(Message::Assistant {
                content: resp.choice.clone(),
            });

            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_depth > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
                }

                // If there are no tool calls, depth is not relevant; we can just return the merged text.
                return Ok(merged_texts);
            }

            let tool_content = stream::iter(tool_calls)
                .then(async |choice| {
                    if let AssistantContent::ToolCall(tool_call) = choice {
                        let output = agent
                            .tools
                            .call(
                                &tool_call.function.name,
                                tool_call.function.arguments.to_string(),
                            )
                            .await?;
                        Ok(UserContent::tool_result(
                            tool_call.id.clone(),
                            OneOrMany::one(output.into()),
                        ))
                    } else {
                        unreachable!(
                            "This should never happen as we already filtered for `ToolCall`"
                        )
                    }
                })
                .collect::<Vec<Result<UserContent, ToolSetError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

            prompt = Message::User {
                content: OneOrMany::many(tool_content).expect("There is at least one tool call"),
            };
        }

        // If we reach here, we never resolved the final tool call; surface the depth limit to the caller.
        Err(PromptError::MaxDepthError {
            max_depth: self.max_depth,
            chat_history: chat_history.clone(),
            prompt,
        })
    }
}
@ -229,6 +229,16 @@ impl Message {
            content: OneOrMany::one(AssistantContent::text(text)),
        }
    }

    // Helper constructor to make creating tool result messages easier.
    pub fn tool_result(id: impl Into<String>, content: impl Into<String>) -> Self {
        Message::User {
            content: OneOrMany::one(UserContent::ToolResult(ToolResult {
                id: id.into(),
                content: OneOrMany::one(ToolResultContent::text(content)),
            })),
        }
    }
}
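// A small usage sketch (the id and payload are illustrative): building a
// tool-result message for a tool call with id "call_1":
//
//     let msg = Message::tool_result("call_1", "{\"temperature\": 20}");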

impl UserContent {
@ -467,6 +477,12 @@ impl From<String> for Text {
    }
}

impl From<&String> for Text {
    fn from(text: &String) -> Self {
        text.to_owned().into()
    }
}

impl From<&str> for Text {
    fn from(text: &str) -> Self {
        text.to_owned().into()
@ -497,6 +513,14 @@ impl From<&str> for Message {
    }
}

impl From<&String> for Message {
    fn from(text: &String) -> Self {
        Message::User {
            content: OneOrMany::one(UserContent::Text(text.into())),
        }
    }
}

impl From<Text> for Message {
    fn from(text: Text) -> Self {
        Message::User {
@ -75,7 +75,7 @@ use crate::{
    tool::ToolSetError,
};

use super::message::AssistantContent;
use super::message::{AssistantContent, ContentFormat, DocumentMediaType};

// Errors
#[derive(Debug, Error)]
@ -108,6 +108,13 @@ pub enum PromptError {

    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),

    #[error("MaxDepthError: (reached limit: {max_depth})")]
    MaxDepthError {
        max_depth: usize,
        chat_history: Vec<Message>,
        prompt: Message,
    },
}
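// A handling sketch, assuming the `agent` from the examples above: when the
// multi-turn depth limit is hit, the partial conversation travels inside the
// error so the caller can inspect or resume it.
//
//     match agent.prompt("plan a trip").multi_turn(2).await {
//         Ok(text) => println!("{text}"),
//         Err(PromptError::MaxDepthError { max_depth, .. }) => {
//             eprintln!("gave up after {max_depth} turns");
//         }
//         Err(e) => eprintln!("{e}"),
//     }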

#[derive(Clone, Debug, Deserialize, Serialize)]
@ -163,7 +170,7 @@ pub trait Prompt: Send + Sync {
    fn prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> impl std::future::Future<Output = Result<String, PromptError>> + Send;
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: Send>;
}
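// With this signature, `prompt` may return a request builder (such as the
// `PromptRequest` above) and still be awaited directly; a sketch:
//
//     let answer = agent.prompt("Hello!").await?;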
/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
@ -180,7 +187,7 @@ pub trait Chat: Send + Sync {
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<String, PromptError>> + Send;
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: Send>;
}

/// Trait defining a low-level LLM completion interface
@ -236,12 +243,11 @@ pub trait CompletionModel: Clone + Send + Sync {

/// Struct representing a general completion request that can be sent to a completion model provider.
pub struct CompletionRequest {
    /// The prompt to be sent to the completion model provider
    pub prompt: Message,
    /// The preamble to be sent to the completion model provider
    pub preamble: Option<String>,
    /// The chat history to be sent to the completion model provider
    pub chat_history: Vec<Message>,
    /// The very last message will always be the prompt (hence why there is *always* one)
    pub chat_history: OneOrMany<Message>,
    /// The documents to be sent to the completion model provider
    pub documents: Vec<Document>,
    /// The tools to be sent to the completion model provider
@ -255,23 +261,33 @@ pub struct CompletionRequest {
}

impl CompletionRequest {
    pub fn prompt_with_context(&self) -> Message {
        let mut new_prompt = self.prompt.clone();
        if let Message::User { ref mut content } = new_prompt {
            if !self.documents.is_empty() {
                let attachments = self
                    .documents
                    .iter()
                    .map(|doc| doc.to_string())
                    .collect::<Vec<_>>()
                    .join("");
                let formatted_content = format!("<attachments>\n{}</attachments>", attachments);
                let mut new_content = vec![UserContent::text(formatted_content)];
                new_content.extend(content.clone());
                *content = OneOrMany::many(new_content).expect("This has more than 1 item");
            }
    /// Returns documents normalized into a message (if any).
    /// Most providers do not accept documents directly as input, so they need to be converted
    /// into a `Message` so that they can be incorporated into the `chat_history` as a user message.
    pub fn normalized_documents(&self) -> Option<Message> {
        if self.documents.is_empty() {
            return None;
        }
        new_prompt

        // Most providers will convert documents into text unless they can handle document messages.
        // We use `UserContent::document` for those that handle it directly!
        let messages = self
            .documents
            .iter()
            .map(|doc| {
                UserContent::document(
                    doc.to_string(),
                    // In the future, we can customize `Document` to pass these extra types through.
                    // Most providers ditch these but they might want to use them.
                    Some(ContentFormat::String),
                    Some(DocumentMediaType::TXT),
                )
            })
            .collect::<Vec<_>>();

        Some(Message::User {
            content: OneOrMany::many(messages).expect("There will be at least one document"),
        })
    }
}
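// A consumption sketch mirroring the provider changes below: when documents
// are present, they are pushed ahead of the chat history before the whole
// sequence is converted into the provider's message type.
//
//     let mut full_history = vec![];
//     if let Some(docs) = completion_request.normalized_documents() {
//         full_history.push(docs);
//     }
//     full_history.extend(completion_request.chat_history);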
@ -446,10 +462,12 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {

    /// Builds the completion request.
    pub fn build(self) -> CompletionRequest {
        let chat_history = OneOrMany::many([self.chat_history, vec![self.prompt]].concat())
            .expect("There will always be at least the prompt");

        CompletionRequest {
            prompt: self.prompt,
            preamble: self.preamble,
            chat_history: self.chat_history,
            chat_history,
            documents: self.documents,
            tools: self.tools,
            temperature: self.temperature,
@ -475,7 +493,6 @@ impl<M: StreamingCompletionModel> CompletionRequestBuilder<M> {

#[cfg(test)]
mod tests {
    use crate::OneOrMany;

    use super::*;

@ -513,7 +530,7 @@ mod tests {
    }

    #[test]
    fn test_prompt_with_context_with_documents() {
    fn test_normalize_documents_with_documents() {
        let doc1 = Document {
            id: "doc1".to_string(),
            text: "Document 1 text.".to_string(),
@ -527,9 +544,8 @@ mod tests {
        };

        let request = CompletionRequest {
            prompt: "What is the capital of France?".into(),
            preamble: None,
            chat_history: Vec::new(),
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: vec![doc1, doc2],
            tools: Vec::new(),
            temperature: None,
@ -539,19 +555,35 @@ mod tests {

        let expected = Message::User {
            content: OneOrMany::many(vec![
                UserContent::text(concat!(
                    "<attachments>\n",
                    "<file id: doc1>\nDocument 1 text.\n</file>\n",
                    "<file id: doc2>\nDocument 2 text.\n</file>\n",
                    "</attachments>"
                )),
                UserContent::text("What is the capital of France?"),
                UserContent::document(
                    "<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
                    Some(ContentFormat::String),
                    Some(DocumentMediaType::TXT),
                ),
                UserContent::document(
                    "<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
                    Some(ContentFormat::String),
                    Some(DocumentMediaType::TXT),
                ),
            ])
            .expect("This has more than 1 item"),
            .expect("There will be at least one document"),
        };

        assert_eq!(request.prompt_with_context(), expected);
        assert_eq!(request.normalized_documents(), Some(expected));
    }

    #[test]
    fn test_normalize_documents_without_documents() {
        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        };

        assert_eq!(request.normalized_documents(), None);
    }
}
@ -36,10 +36,13 @@ use serde_json::json;

use crate::{
    agent::{Agent, AgentBuilder},
    completion::{CompletionModel, Prompt, PromptError, ToolDefinition},
    completion::{Completion, CompletionError, CompletionModel, ToolDefinition},
    message::{AssistantContent, Message, ToolCall, ToolFunction},
    tool::Tool,
};

const SUBMIT_TOOL_NAME: &str = "submit";

#[derive(Debug, thiserror::Error)]
pub enum ExtractionError {
    #[error("No data extracted")]
@ -48,8 +51,8 @@ pub enum ExtractionError {
    #[error("Failed to deserialize the extracted data: {0}")]
    DeserializationError(#[from] serde_json::Error),

    #[error("PromptError: {0}")]
    PromptError(#[from] PromptError),
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),
}

/// Extractor for structured data from text
@ -62,14 +65,43 @@ impl<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, M: CompletionModel>
where
    M: Sync,
{
    pub async fn extract(&self, text: &str) -> Result<T, ExtractionError> {
        let summary = self.agent.prompt(text).await?;
    pub async fn extract(&self, text: impl Into<Message> + Send) -> Result<T, ExtractionError> {
        let response = self.agent.completion(text, vec![]).await?.send().await?;

        if summary.is_empty() {
            return Err(ExtractionError::NoData);
        let arguments = response
            .choice
            .into_iter()
            // We filter tool calls to look for submit tool calls
            .filter_map(|content| {
                if let AssistantContent::ToolCall(ToolCall {
                    function: ToolFunction { arguments, name },
                    ..
                }) = content
                {
                    if name == SUBMIT_TOOL_NAME {
                        Some(arguments)
                    } else {
                        None
                    }
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();

        if arguments.len() > 1 {
            tracing::warn!(
                "Multiple submit calls detected, using the first one. Providers / agents should ensure only one submit call is made."
            );
        }

        Ok(serde_json::from_str(&summary)?)
        let raw_data = if let Some(arg) = arguments.into_iter().next() {
            arg
        } else {
            return Err(ExtractionError::NoData);
        };

        Ok(serde_json::from_value(raw_data)?)
    }
}
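// A usage sketch, assuming an OpenAI-style client as elsewhere in the crate
// (the model name and target struct are illustrative):
//
//     #[derive(serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
//     struct Person {
//         name: String,
//         age: u8,
//     }
//
//     let extractor = openai_client.extractor::<Person>(openai::GPT_4O).build();
//     let person = extractor.extract("John Doe is 42 years old").await?;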
@ -132,7 +164,7 @@ struct SubmitTool<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> {
struct SubmitError;

impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync> Tool for SubmitTool<T> {
    const NAME: &'static str = "submit";
    const NAME: &'static str = SUBMIT_TOOL_NAME;
    type Error = SubmitError;
    type Args = T;
    type Output = T;
@ -1,6 +1,9 @@
use std::future::IntoFuture;

use crate::{
    completion::{self, CompletionModel},
    extractor::{ExtractionError, Extractor},
    message::Message,
    vector_store,
};

@ -79,14 +82,14 @@ impl<P, In> Prompt<P, In> {

impl<P, In> Op for Prompt<P, In>
where
    P: completion::Prompt,
    P: completion::Prompt + Send + Sync,
    In: Into<String> + Send + Sync,
{
    type Input = In;
    type Output = Result<String, completion::PromptError>;

    async fn call(&self, input: Self::Input) -> Self::Output {
        self.prompt.prompt(input.into()).await
    fn call(&self, input: Self::Input) -> impl std::future::Future<Output = Self::Output> + Send {
        self.prompt.prompt(input.into()).into_future()
    }
}

@ -127,13 +130,13 @@ impl<M, Input, Output> Op for Extract<M, Input, Output>
where
    M: CompletionModel,
    Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
    Input: Into<String> + Send + Sync,
    Input: Into<Message> + Send + Sync,
{
    type Input = Input;
    type Output = Result<Output, ExtractionError>;

    async fn call(&self, input: Self::Input) -> Self::Output {
        self.extractor.extract(&input.into()).await
        self.extractor.extract(input).await
    }
}

@ -159,6 +162,7 @@ pub mod tests {
    pub struct MockModel;

    impl Prompt for MockModel {
        #[allow(refining_impl_trait)]
        async fn prompt(&self, prompt: impl Into<message::Message>) -> Result<String, PromptError> {
            let msg: message::Message = prompt.into();
            let prompt = match msg {
@ -553,26 +553,20 @@ impl completion::CompletionModel for CompletionModel {
            ));
        };

        let prompt_message: Message = completion_request
            .prompt_with_context()
            .try_into()
            .map_err(|e: MessageError| CompletionError::RequestError(e.into()))?;
        let mut full_history = vec![];
        if let Some(docs) = completion_request.normalized_documents() {
            full_history.push(docs);
        }
        full_history.extend(completion_request.chat_history);

        let mut messages = completion_request
            .chat_history
        let full_history = full_history
            .into_iter()
            .map(|message| {
                message
                    .try_into()
                    .map_err(|e: MessageError| CompletionError::RequestError(e.into()))
            })
            .map(Message::try_from)
            .collect::<Result<Vec<Message>, _>>()?;

        messages.push(prompt_message);

        let mut request = json!({
            "model": self.model,
            "messages": messages,
            "messages": full_history,
            "max_tokens": max_tokens,
            "system": completion_request.preamble.unwrap_or("".to_string()),
        });
@ -7,7 +7,6 @@ use super::completion::{CompletionModel, Content, Message, ToolChoice, ToolDefin
use super::decoders::sse::from_response as sse_from_response;
use crate::completion::{CompletionError, CompletionRequest};
use crate::json_utils::merge_inplace;
use crate::message::MessageError;
use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};

#[derive(Debug, Deserialize)]
@ -90,26 +89,20 @@ impl StreamingCompletionModel for CompletionModel {
            ));
        };

        let prompt_message: Message = completion_request
            .prompt_with_context()
            .try_into()
            .map_err(|e: MessageError| CompletionError::RequestError(e.into()))?;
        let mut full_history = vec![];
        if let Some(docs) = completion_request.normalized_documents() {
            full_history.push(docs);
        }
        full_history.extend(completion_request.chat_history);

        let mut messages = completion_request
            .chat_history
        let full_history = full_history
            .into_iter()
            .map(|message| {
                message
                    .try_into()
                    .map_err(|e: MessageError| CompletionError::RequestError(e.into()))
            })
            .map(Message::try_from)
            .collect::<Result<Vec<Message>, _>>()?;

        messages.push(prompt_message);

        let mut request = json!({
            "model": self.model,
            "messages": messages,
            "messages": full_history,
            "max_tokens": max_tokens,
            "system": completion_request.preamble.unwrap_or("".to_string()),
            "stream": true,
@ -480,16 +480,14 @@ impl CompletionModel {
        &self,
        completion_request: CompletionRequest,
    ) -> Result<serde_json::Value, CompletionError> {
        // Add preamble to chat history (if available)
        let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
            Some(preamble) => vec![openai::Message::system(preamble)],
            None => vec![],
        };

        // Convert prompt to user message
        let prompt: Vec<openai::Message> = completion_request.prompt_with_context().try_into()?;

        // Convert existing chat history
        if let Some(docs) = completion_request.normalized_documents() {
            let docs: Vec<openai::Message> = docs.try_into()?;
            full_history.extend(docs);
        }
        let chat_history: Vec<openai::Message> = completion_request
            .chat_history
            .into_iter()
@ -499,9 +497,7 @@ impl CompletionModel {
            .flatten()
            .collect();

        // Combine all messages into a single history
        full_history.extend(chat_history);
        full_history.extend(prompt);

        let request = if completion_request.tools.is_empty() {
            json!({
@ -786,6 +782,7 @@ mod azure_tests {

    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;
    use crate::OneOrMany;

    #[tokio::test]
    #[ignore]
@ -812,8 +809,7 @@ mod azure_tests {
        let completion = model
            .completion(CompletionRequest {
                preamble: Some("You are a helpful assistant.".to_string()),
                chat_history: vec![],
                prompt: "Hello, world!".into(),
                chat_history: OneOrMany::one("Hello!".into()),
                documents: vec![],
                max_tokens: Some(100),
                temperature: Some(0.0),
@ -440,29 +440,34 @@ impl completion::CompletionModel for CompletionModel {
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let prompt = completion_request.prompt_with_context();
        // Build up the order of messages (context, chat_history)
        let mut partial_history = vec![];
        if let Some(docs) = completion_request.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(completion_request.chat_history);

        let mut messages: Vec<message::Message> =
            if let Some(preamble) = completion_request.preamble {
                vec![preamble.into()]
            } else {
                vec![]
            };
        // Initialize full history with preamble (or empty if non-existent)
        let mut full_history: Vec<Message> = completion_request
            .preamble
            .map_or_else(Vec::new, |preamble| {
                vec![Message::System { content: preamble }]
            });

        messages.extend(completion_request.chat_history);
        messages.push(prompt);

        let messages: Vec<Message> = messages
            .into_iter()
            .map(|msg| msg.try_into())
            .collect::<Result<Vec<Vec<_>>, _>>()?
            .into_iter()
            .flatten()
            .collect();
        // Convert and extend the rest of the history
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        let request = json!({
            "model": self.model,
            "messages": messages,
            "messages": full_history,
            "documents": completion_request.documents,
            "temperature": completion_request.temperature,
            "tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
@ -379,28 +379,28 @@ impl DeepSeekCompletionModel {
        &self,
        completion_request: CompletionRequest,
    ) -> Result<serde_json::Value, CompletionError> {
        // Add preamble to chat history (if available)
        let mut full_history: Vec<Message> = match &completion_request.preamble {
            Some(preamble) => vec![Message::system(preamble)],
            None => vec![],
        };
        // Build up the order of messages (context, chat_history, prompt)
        let mut partial_history = vec![];
        if let Some(docs) = completion_request.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(completion_request.chat_history);

        // Convert prompt to user message
        let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
        // Initialize full history with preamble (or empty if non-existent)
        let mut full_history: Vec<Message> = completion_request
            .preamble
            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);

        // Convert existing chat history
        let chat_history: Vec<Message> = completion_request
            .chat_history
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        // Combine all messages into a single history
        full_history.extend(chat_history);
        full_history.extend(prompt);
        // Convert and extend the rest of the history
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        let request = if completion_request.tools.is_empty() {
            json!({
@ -398,6 +398,13 @@ impl CompletionModel {
        &self,
        completion_request: CompletionRequest,
    ) -> Result<Value, CompletionError> {
        // Build up the order of messages (context, chat_history, prompt)
        let mut partial_history = vec![];
        if let Some(docs) = completion_request.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(completion_request.chat_history);

        // Add preamble to chat history (if available)
        let mut full_history: Vec<Message> = match &completion_request.preamble {
            Some(preamble) => vec![Message {
@ -408,19 +415,13 @@ impl CompletionModel {
            None => vec![],
        };

        // Convert prompt to user message
        let prompt: Message = completion_request.prompt_with_context().try_into()?;

        // Convert existing chat history
        let chat_history: Vec<Message> = completion_request
            .chat_history
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Message>, _>>()?;

        // Combine all messages into a single history
        full_history.extend(chat_history);
        full_history.push(prompt);
        // Convert and extend the rest of the history
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Message>, _>>()?,
        );

        let request = if completion_request.tools.is_empty() {
            json!({
@ -93,11 +93,10 @@ impl completion::CompletionModel for CompletionModel {
}

pub(crate) fn create_request_body(
    mut completion_request: CompletionRequest,
    completion_request: CompletionRequest,
) -> Result<GenerateContentRequest, CompletionError> {
    let mut full_history = Vec::new();
    full_history.append(&mut completion_request.chat_history);
    full_history.push(completion_request.prompt_with_context());
    full_history.extend(completion_request.chat_history);

    let additional_params = completion_request
        .additional_params
@ -279,28 +279,31 @@ impl CompletionModel {
        &self,
        completion_request: CompletionRequest,
    ) -> Result<Value, CompletionError> {
        // Add preamble to chat history (if available)
        let mut full_history: Vec<Message> = match &completion_request.preamble {
            Some(preamble) => vec![Message {
                role: "system".to_string(),
                content: Some(preamble.to_string()),
            }],
            None => vec![],
        };
        // Build up the order of messages (context, chat_history, prompt)
        let mut partial_history = vec![];
        if let Some(docs) = completion_request.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(completion_request.chat_history);

        // Convert prompt to user message
        let prompt: Message = completion_request.prompt_with_context().try_into()?;
        // Initialize full history with preamble (or empty if non-existent)
        let mut full_history: Vec<Message> =
            completion_request
                .preamble
                .map_or_else(Vec::new, |preamble| {
                    vec![Message {
                        role: "system".to_string(),
                        content: Some(preamble),
                    }]
                });

        // Convert existing chat history
        let chat_history: Vec<Message> = completion_request
            .chat_history
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Message>, _>>()?;

        // Combine all messages into a single history
        full_history.extend(chat_history);
        full_history.push(prompt);
        // Convert and extend the rest of the history
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Message>, _>>()?,
        );

        let request = if completion_request.tools.is_empty() {
            json!({
@ -502,8 +502,10 @@ impl CompletionModel {
            Some(preamble) => vec![Message::system(preamble)],
            None => vec![],
        };

        let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
        if let Some(docs) = completion_request.normalized_documents() {
            let docs: Vec<Message> = docs.try_into()?;
            full_history.extend(docs);
        }

        let chat_history: Vec<Message> = completion_request
            .chat_history
@ -516,7 +518,6 @@ impl CompletionModel {
            .collect();

        full_history.extend(chat_history);
        full_history.extend(prompt);

        let model = self.client.sub_provider.model_identifier(&self.model);

@ -12,6 +12,7 @@
use super::openai::{send_compatible_streaming_request, AssistantContent};

use crate::json_utils::merge_inplace;
use crate::message;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::{
    agent::AgentBuilder,
@ -306,28 +307,28 @@ impl CompletionModel {
        &self,
        completion_request: CompletionRequest,
    ) -> Result<Value, CompletionError> {
        // Add preamble to chat history (if available)
        let mut full_history: Vec<Message> = match &completion_request.preamble {
            Some(preamble) => vec![Message::system(preamble)],
            None => vec![],
        };
        // Build up the order of messages (context, chat_history, prompt)
        let mut partial_history = vec![];
        if let Some(docs) = completion_request.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(completion_request.chat_history);

        // Convert prompt to user message
        let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
        // Initialize full history with preamble (or empty if non-existent)
        let mut full_history: Vec<Message> = completion_request
            .preamble
            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);

        // Convert existing chat history
        let chat_history: Vec<Message> = completion_request
            .chat_history
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        // Combine all messages into a single history
        full_history.extend(chat_history);
        full_history.extend(prompt);
        // Convert and extend the rest of the history
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        let request = json!({
            "model": self.model,
@ -238,24 +238,25 @@ impl CompletionModel {
            }));
        }

        // Add prompt
        messages.push(match &completion_request.prompt {
            Message::User { content } => {
                let text = content
                    .iter()
                    .map(|c| match c {
                        UserContent::Text(text) => &text.text,
                        _ => "",
                    })
                    .collect::<Vec<_>>()
                    .join("\n");
                serde_json::json!({
                    "role": "user",
                    "content": text
        // Add docs
        if let Some(Message::User { content }) = completion_request.normalized_documents() {
            let text = content
                .into_iter()
                .filter_map(|doc| match doc {
                    UserContent::Document(doc) => Some(doc.data),
                    UserContent::Text(text) => Some(text.text),

                    // This should always be `Document`
                    _ => None,
                })
            }
            _ => unreachable!(),
        });
                .collect::<Vec<_>>()
                .join("\n");

            messages.push(serde_json::json!({
                "role": "user",
                "content": text
            }));
        }

        // Add chat history
        for msg in completion_request.chat_history {
@ -10,6 +10,7 @@
//! ```

use crate::json_utils::merge;
use crate::message;
use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::{
@ -141,28 +142,30 @@ impl CompletionModel {
        &self,
        completion_request: CompletionRequest,
    ) -> Result<Value, CompletionError> {
        // Add preamble to chat history (if available)
        let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
            Some(preamble) => vec![openai::Message::system(preamble)],
            None => vec![],
        };
        // Build up the order of messages (context, chat_history)
        let mut partial_history = vec![];
        if let Some(docs) = completion_request.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(completion_request.chat_history);

        // Convert prompt to user message
        let prompt: Vec<openai::Message> = completion_request.prompt_with_context().try_into()?;
        // Initialize full history with preamble (or empty if non-existent)
        let mut full_history: Vec<openai::Message> = completion_request
            .preamble
            .map_or_else(Vec::new, |preamble| {
                vec![openai::Message::system(&preamble)]
            });

        // Convert existing chat history
        let chat_history: Vec<openai::Message> = completion_request
            .chat_history
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        // Combine all messages into a single history
        full_history.extend(chat_history);
        full_history.extend(prompt);
        // Convert and extend the rest of the history
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        let request = if completion_request.tools.is_empty() {
            json!({
@ -325,8 +325,27 @@ impl CompletionModel {
        &self,
        completion_request: CompletionRequest,
    ) -> Result<Value, CompletionError> {
        // Build up the order of messages (context, chat_history)
        let mut partial_history = vec![];
        if let Some(docs) = completion_request.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(completion_request.chat_history);

        // Initialize full history with preamble (or empty if non-existent)
        let mut full_history: Vec<Message> = completion_request
            .preamble
            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);

        // Convert and extend the rest of the history
        full_history.extend(
            partial_history
                .into_iter()
                .map(|msg| msg.try_into())
                .collect::<Result<Vec<Message>, _>>()?,
        );

        // Convert internal prompt into a provider Message
        let prompt: Message = completion_request.prompt_with_context().try_into()?;
        let options = if let Some(extra) = completion_request.additional_params {
            json_utils::merge(
                json!({ "temperature": completion_request.temperature }),
@ -336,16 +355,6 @@ impl CompletionModel {
            json!({ "temperature": completion_request.temperature })
        };

        // Chat mode: assemble full conversation history including preamble and chat history
        let mut full_history = Vec::new();
        if let Some(preamble) = completion_request.preamble {
            full_history.push(Message::system(&preamble));
        }
        for msg in completion_request.chat_history.into_iter() {
            full_history.push(Message::try_from(msg)?);
        }
        full_history.push(prompt);

        let mut request_payload = json!({
            "model": self.model,
            "messages": full_history,
@ -605,28 +605,28 @@ impl CompletionModel {
        &self,
        completion_request: CompletionRequest,
    ) -> Result<Value, CompletionError> {
        // Add preamble to chat history (if available)
        let mut full_history: Vec<Message> = match &completion_request.preamble {
            Some(preamble) => vec![Message::system(preamble)],
            None => vec![],
        };
        // Build up the order of messages (context, chat_history)
        let mut partial_history = vec![];
        if let Some(docs) = completion_request.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(completion_request.chat_history);

        // Convert prompt to user message
        let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
        // Initialize full history with preamble (or empty if non-existent)
        let mut full_history: Vec<Message> = completion_request
            .preamble
            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);

        // Convert existing chat history
        let chat_history: Vec<Message> = completion_request
            .chat_history
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        // Combine all messages into a single history
        full_history.extend(chat_history);
        full_history.extend(prompt);
        // Convert and extend the rest of the history
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        let request = if completion_request.tools.is_empty() {
            json!({
@ -269,8 +269,11 @@ impl completion::CompletionModel for CompletionModel {
            None => vec![],
        };

        // Convert prompt to user message
        let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
        // Gather docs
        if let Some(docs) = completion_request.normalized_documents() {
            let docs: Vec<Message> = docs.try_into()?;
            full_history.extend(docs);
        }

        // Convert existing chat history
        let chat_history: Vec<Message> = completion_request
@ -284,7 +287,6 @@ impl completion::CompletionModel for CompletionModel {

        // Combine all messages into a single history
        full_history.extend(chat_history);
        full_history.extend(prompt);

        let request = json!({
            "model": self.model,
@ -204,39 +204,36 @@ impl CompletionModel {
        &self,
        completion_request: CompletionRequest,
    ) -> Result<Value, CompletionError> {
        // Add context documents to current prompt
        let prompt_with_context = completion_request.prompt_with_context();

        // Add preamble to messages (if available)
        let mut messages: Vec<Message> = if let Some(preamble) = completion_request.preamble {
            vec![Message {
                role: Role::System,
                content: preamble,
            }]
        } else {
            vec![]
        };

        // Add chat history to messages
        for message in completion_request.chat_history {
            messages.push(
                message
                    .try_into()
                    .map_err(|e: MessageError| CompletionError::RequestError(e.into()))?,
            );
        // Build up the order of messages (context, chat_history, prompt)
        let mut partial_history = vec![];
        if let Some(docs) = completion_request.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(completion_request.chat_history);

        // Add user prompt to messages
        messages.push(
            prompt_with_context
                .try_into()
                .map_err(|e: MessageError| CompletionError::RequestError(e.into()))?,
        // Initialize full history with preamble (or empty if non-existent)
        let mut full_history: Vec<Message> =
            completion_request
                .preamble
                .map_or_else(Vec::new, |preamble| {
                    vec![Message {
                        role: Role::System,
                        content: preamble,
                    }]
                });

        // Convert and extend the rest of the history
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Message>, _>>()?,
        );

        // Compose request
        let request = json!({
            "model": self.model,
            "messages": messages,
            "messages": full_history,
            "temperature": completion_request.temperature,
        });

@ -148,7 +148,10 @@ impl CompletionModel {
            Some(preamble) => vec![openai::Message::system(preamble)],
            None => vec![],
        };
        let prompt: Vec<openai::Message> = completion_request.prompt_with_context().try_into()?;
        if let Some(docs) = completion_request.normalized_documents() {
            let docs: Vec<openai::Message> = docs.try_into()?;
            full_history.extend(docs);
        }
        let chat_history: Vec<openai::Message> = completion_request
            .chat_history
            .into_iter()
@ -157,8 +160,9 @@ impl CompletionModel {
            .into_iter()
            .flatten()
            .collect();

        full_history.extend(chat_history);
        full_history.extend(prompt);

        let mut request = if completion_request.tools.is_empty() {
            json!({
                "model": self.model,
@ -10,7 +10,6 @@ use crate::{
};

use super::client::{xai_api_types::ApiResponse, Client};
use crate::completion::CompletionRequest;
use serde_json::{json, Value};
use xai_api_types::{CompletionResponse, ToolDefinition};

@ -30,22 +29,13 @@ pub struct CompletionModel {
impl CompletionModel {
    pub(crate) fn create_completion_request(
        &self,
        completion_request: CompletionRequest,
        completion_request: completion::CompletionRequest,
    ) -> Result<Value, CompletionError> {
        // Add preamble to chat history (if available)
        let mut full_history: Vec<Message> = match &completion_request.preamble {
            Some(preamble) => {
                if preamble.is_empty() {
                    vec![]
                } else {
                    vec![Message::system(preamble)]
                }
            }
            None => vec![],
        };

        // Convert prompt to user message
        let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
        // Convert documents into user message
        let docs: Option<Vec<Message>> = completion_request
            .normalized_documents()
            .map(|docs| docs.try_into())
            .transpose()?;

        // Convert existing chat history
        let chat_history: Vec<Message> = completion_request
@ -57,9 +47,19 @@ impl CompletionModel {
            .flatten()
            .collect();

        // Combine all messages into a single history
        // Init full history with preamble (or empty if non-existent)
        let mut full_history: Vec<Message> = match &completion_request.preamble {
            Some(preamble) => vec![Message::system(preamble)],
            None => vec![],
        };

        // Docs appear right after preamble, if they exist
        if let Some(docs) = docs {
            full_history.extend(docs)
        }

        // Chat history and prompt appear in the order they were provided
        full_history.extend(chat_history);
        full_history.extend(prompt);

        let mut request = if completion_request.tools.is_empty() {
            json!({
@ -15,6 +15,7 @@ use rig::agent::AgentBuilder;
use rig::completion::{CompletionError, CompletionRequest};
use rig::embeddings::{EmbeddingError, EmbeddingsBuilder};
use rig::extractor::ExtractorBuilder;
use rig::message;
use rig::providers::openai::{self, Message};
use rig::OneOrMany;
use rig::{completion, embeddings, Embed};
@ -470,14 +471,19 @@ impl completion::CompletionModel for CompletionModel {
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Add preamble to chat history (if available)
        let mut full_history: Vec<Message> = match &completion_request.preamble {
            Some(preamble) => vec![Message::system(preamble)],
            None => vec![],
        };
        // Build up the order of messages (context, chat_history)
        let mut partial_history = vec![];
        if let Some(docs) = completion_request.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(completion_request.chat_history);

        // Initialize full history with preamble (or empty if non-existent)
        let mut full_history: Vec<Message> = completion_request
            .preamble
            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);

        // Convert prompt to user message
        let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
        tracing::info!("Try to get on-chain system prompt");
        let eternal_ai_rpc = std::env::var("ETERNALAI_RPC_URL").unwrap_or_else(|_| "".to_string());
        let eternal_ai_contract =
@ -515,19 +521,16 @@ impl completion::CompletionModel for CompletionModel {
            }
        }

        // Convert existing chat history
        let chat_history: Vec<Message> = completion_request
            .chat_history
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        // Combine all messages into a single history
        full_history.extend(chat_history);
        full_history.extend(prompt);
        // Convert and extend the rest of the history
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        let request = if completion_request.tools.is_empty() {
            json!({