mirror of https://github.com/0xplaygrounds/rig
Merge 1dd15d8f8c into 33e8fc7a65
This commit is contained in:
commit 7245a06922
@@ -120,13 +120,13 @@ impl completion::CompletionModel for CompletionModel {
             .model_id(self.model.as_str());

         let tool_config = request.tools_config()?;
-        let prompt_with_history = request.prompt_with_history()?;
+        let messages = request.messages()?;
         converse_builder = converse_builder
             .set_additional_model_request_fields(request.additional_params())
             .set_inference_config(request.inference_config())
             .set_tool_config(tool_config)
             .set_system(request.system_prompt())
-            .set_messages(Some(prompt_with_history));
+            .set_messages(Some(messages));

         let response = converse_builder
             .send()
@@ -28,7 +28,7 @@ impl StreamingCompletionModel for CompletionModel {
             .model_id(self.model.as_str());

         let tool_config = request.tools_config()?;
-        let prompt_with_history = request.prompt_with_history()?;
+        let prompt_with_history = request.messages()?;
         converse_builder = converse_builder
             .set_additional_model_request_fields(request.additional_params())
             .set_inference_config(request.inference_config())
@@ -6,6 +6,8 @@ use aws_sdk_bedrockruntime::types::{
     ToolSpecification,
 };
 use rig::completion::{CompletionError, Message};
+use rig::message::{ContentFormat, DocumentMediaType, UserContent};
+use rig::OneOrMany;

 pub struct AwsCompletionRequest(pub rig::completion::CompletionRequest);

@@ -69,13 +71,30 @@ impl AwsCompletionRequest {
             .map(|system_prompt| vec![SystemContentBlock::Text(system_prompt)])
     }

-    pub fn prompt_with_history(&self) -> Result<Vec<aws_bedrock::Message>, CompletionError> {
-        let mut chat_history = self.0.chat_history.to_owned();
-        let prompt_with_context = self.0.prompt_with_context();
-
+    pub fn messages(&self) -> Result<Vec<aws_bedrock::Message>, CompletionError> {
         let mut full_history: Vec<Message> = Vec::new();
-        full_history.append(&mut chat_history);
-        full_history.push(prompt_with_context);
+        if !self.0.documents.is_empty() {
+            let messages = self
+                .0
+                .documents
+                .iter()
+                .map(|doc| doc.to_string())
+                .collect::<Vec<_>>()
+                .join(" | ");
+
+            let content = OneOrMany::one(UserContent::document(
+                messages,
+                Some(ContentFormat::String),
+                Some(DocumentMediaType::TXT),
+            ));
+
+            full_history.push(Message::User { content });
+        }
+
+        self.0.chat_history.iter().for_each(|message| {
+            full_history.push(message.clone());
+        });

         full_history
             .into_iter()
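The rename above is also a behaviour change: `messages()` no longer appends the prompt onto the history itself; instead it folds any attached documents into a single text document message and then copies over the chat history. A minimal usage sketch under that reading (variable names are hypothetical; `completion_request` stands in for a rig `CompletionRequest` built elsewhere):

    let request = AwsCompletionRequest(completion_request);

    // Documents, if any, are joined with " | " and sent as one user message,
    // followed by every message already in the chat history.
    let messages: Vec<aws_bedrock::Message> = request.messages()?;
    converse_builder = converse_builder.set_messages(Some(messages));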
@ -11,7 +11,7 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
|
|
||||||
// Create agent with a single context prompt
|
// Create agent with a single context prompt
|
||||||
let comedian_agent = client
|
let comedian_agent = client
|
||||||
.agent("cognitivecomputations/dolphin3.0-mistral-24b:free")
|
.agent("google/gemini-2.5-pro-exp-03-25:free")
|
||||||
.preamble("You are a comedian here to entertain the user using humour and jokes.")
|
.preamble("You are a comedian here to entertain the user using humour and jokes.")
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
|
|
@@ -3,7 +3,7 @@ use std::env;
 use anyhow::Result;
 use rig::{
     agent::Agent,
-    completion::Chat,
+    completion::Prompt,
     message::Message,
     providers::{cohere, openai},
 };
@@ -49,20 +49,20 @@ impl Debater {

             let resp_a = self
                 .gpt_4
-                .chat(prompt_a.as_str(), history_a.clone())
+                .prompt(prompt_a.as_str())
+                .with_history(&mut history_a)
                 .await?;
             println!("GPT-4:\n{}", resp_a);
-            history_a.push(Message::user(prompt_a));
-            history_a.push(Message::assistant(resp_a.clone()));
             println!("================================================================");

-            let resp_b = self.coral.chat(resp_a.as_str(), history_b.clone()).await?;
+            let resp_b = self
+                .coral
+                .prompt(resp_a.as_str())
+                .with_history(&mut history_b)
+                .await?;
             println!("Coral:\n{}", resp_b);
             println!("================================================================");

-            history_b.push(Message::user(resp_a));
-            history_b.push(Message::assistant(resp_b.clone()));

             last_resp_b = Some(resp_b)
         }

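The switch from `Chat::chat` to the fluent `Prompt` API means the history vector is now borrowed mutably and updated in place, which is why the manual `Message::user`/`Message::assistant` pushes disappear. A minimal sketch of the pattern, assuming an agent built from any provider client:

    let mut history: Vec<Message> = Vec::new();

    // The user turn and the assistant reply are appended to `history`
    // by the prompt request itself; no manual pushes required.
    let reply = agent
        .prompt("Opening argument, please.")
        .with_history(&mut history)
        .await?;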
@@ -35,6 +35,7 @@ impl<M: CompletionModel> EnglishTranslator<M> {
     }
 }

 impl<M: CompletionModel> Chat for EnglishTranslator<M> {
+    #[allow(refining_impl_trait)]
     async fn chat(
         &self,
         prompt: impl Into<Message> + Send,
@@ -1,99 +1,23 @@
 use rig::{
-    agent::Agent,
-    completion::{self, Completion, PromptError, ToolDefinition},
-    message::{AssistantContent, Message, ToolCall, ToolFunction, ToolResultContent, UserContent},
+    completion::{Prompt, ToolDefinition},
     providers::anthropic,
     tool::Tool,
-    OneOrMany,
 };
 use serde::{Deserialize, Serialize};
 use serde_json::json;

-struct MultiTurnAgent<M: rig::completion::CompletionModel> {
-    agent: Agent<M>,
-    chat_history: Vec<completion::Message>,
-}
-
-impl<M: rig::completion::CompletionModel> MultiTurnAgent<M> {
-    async fn multi_turn_prompt(
-        &mut self,
-        prompt: impl Into<Message> + Send,
-    ) -> Result<String, PromptError> {
-        let mut current_prompt: Message = prompt.into();
-        loop {
-            println!("Current Prompt: {:?}\n", current_prompt);
-            let resp = self
-                .agent
-                .completion(current_prompt.clone(), self.chat_history.clone())
-                .await?
-                .send()
-                .await?;
-
-            let mut final_text = None;
-
-            for content in resp.choice.into_iter() {
-                match content {
-                    AssistantContent::Text(text) => {
-                        println!("Intermediate Response: {:?}\n", text.text);
-                        final_text = Some(text.text.clone());
-                        self.chat_history.push(current_prompt.clone());
-                        let response_message = Message::Assistant {
-                            content: OneOrMany::one(AssistantContent::text(&text.text)),
-                        };
-                        self.chat_history.push(response_message);
-                    }
-                    AssistantContent::ToolCall(content) => {
-                        self.chat_history.push(current_prompt.clone());
-                        let tool_call_msg = AssistantContent::ToolCall(content.clone());
-                        println!("Tool Call Msg: {:?}\n", tool_call_msg);
-
-                        self.chat_history.push(Message::Assistant {
-                            content: OneOrMany::one(tool_call_msg),
-                        });
-
-                        let ToolCall {
-                            id,
-                            function: ToolFunction { name, arguments },
-                        } = content;
-
-                        let tool_result =
-                            self.agent.tools.call(&name, arguments.to_string()).await?;
-
-                        current_prompt = Message::User {
-                            content: OneOrMany::one(UserContent::tool_result(
-                                id,
-                                OneOrMany::one(ToolResultContent::text(tool_result)),
-                            )),
-                        };
-
-                        final_text = None;
-                        break;
-                    }
-                }
-            }
-
-            if let Some(text) = final_text {
-                return Ok(text);
-            }
-        }
-    }
-}
-
 #[tokio::main]
 async fn main() -> anyhow::Result<()> {
-    // tracing_subscriber::registry()
-    //     .with(
-    //         tracing_subscriber::EnvFilter::try_from_default_env()
-    //             .unwrap_or_else(|_| "stdout=info".into()),
-    //     )
-    //     .with(tracing_subscriber::fmt::layer())
-    //     .init();
+    tracing_subscriber::fmt()
+        .with_max_level(tracing::Level::DEBUG)
+        .with_target(false)
+        .init();

     // Create OpenAI client
     let openai_client = anthropic::Client::from_env();

     // Create RAG agent with a single context prompt and a dynamic tool source
-    let calculator_rag = openai_client
+    let agent = openai_client
         .agent(anthropic::CLAUDE_3_5_SONNET)
         .preamble(
             "You are an assistant here to help the user select which tool is most appropriate to perform arithmetic operations.
@@ -109,21 +33,18 @@ async fn main() -> anyhow::Result<()> {
         .tool(Divide)
         .build();

-    let mut agent = MultiTurnAgent {
-        agent: calculator_rag,
-        chat_history: Vec::new(),
-    };
-
     // Prompt the agent and print the response
     let result = agent
-        .multi_turn_prompt("Calculate 5 - 2 = ?. Describe the result to me.")
+        .prompt("Calculate 5 - 2 = ?. Describe the result to me.")
+        .multi_turn(20)
         .await?;

     println!("\n\nOpenAI Calculator Agent: {}", result);

     // Prompt the agent again and print the response
     let result = agent
-        .multi_turn_prompt("Calculate (3 + 5) / 9 = ?. Describe the result to me.")
+        .prompt("Calculate (3 + 5) / 9 = ?. Describe the result to me.")
+        .multi_turn(20)
         .await?;

     println!("\n\nOpenAI Calculator Agent: {}", result);

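The hand-rolled `MultiTurnAgent` loop from the old example is subsumed by the built-in `multi_turn(n)` builder method, where `n` caps the number of tool-call round trips before the request gives up. A sketch combining it with an external history buffer, as the new reasoning example further below also does:

    let mut history = Vec::new();
    let answer = agent
        .prompt("Calculate (3 + 5) / 9 = ?")
        .with_history(&mut history) // optional: capture the intermediate tool-call turns
        .multi_turn(20)             // cap on tool-call round trips
        .await?;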
@@ -0,0 +1,261 @@
+use rig::{
+    agent::Agent,
+    completion::{CompletionError, CompletionModel, Prompt, PromptError, ToolDefinition},
+    extractor::Extractor,
+    message::Message,
+    providers::anthropic,
+    tool::Tool,
+};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+
+const CHAIN_OF_THOUGHT_PROMPT: &str = "
+You are an assistant that extracts reasoning steps from a given prompt.
+Do not return text, only return a tool call.
+";
+
+#[derive(Deserialize, Serialize, Debug, Clone, JsonSchema)]
+struct ChainOfThoughtSteps {
+    steps: Vec<String>,
+}
+
+struct ReasoningAgent<M: CompletionModel> {
+    chain_of_thought_extractor: Extractor<M, ChainOfThoughtSteps>,
+    executor: Agent<M>,
+}
+
+impl<M: CompletionModel> Prompt for ReasoningAgent<M> {
+    #[allow(refining_impl_trait)]
+    async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
+        let prompt: Message = prompt.into();
+        let mut chat_history = vec![prompt.clone()];
+        let extracted = self
+            .chain_of_thought_extractor
+            .extract(prompt)
+            .await
+            .map_err(|e| {
+                tracing::error!("Extraction error: {:?}", e);
+                CompletionError::ProviderError("".into())
+            })?;
+
+        if extracted.steps.is_empty() {
+            return Ok("No reasoning steps provided.".into());
+        }
+
+        let mut reasoning_prompt = String::new();
+        for (i, step) in extracted.steps.iter().enumerate() {
+            reasoning_prompt.push_str(&format!("Step {}: {}\n", i + 1, step));
+        }
+
+        let response = self
+            .executor
+            .prompt(reasoning_prompt.as_str())
+            .with_history(&mut chat_history)
+            .multi_turn(20)
+            .await?;
+
+        tracing::info!(
+            "full chat history generated: {}",
+            serde_json::to_string_pretty(&chat_history).unwrap()
+        );
+
+        Ok(response)
+    }
+}
+
+#[tokio::main]
+async fn main() -> anyhow::Result<()> {
+    tracing_subscriber::fmt()
+        .with_max_level(tracing::Level::DEBUG)
+        .with_target(false)
+        .init();
+
+    // Create OpenAI client
+    let anthropic_client = anthropic::Client::from_env();
+
+    let agent = ReasoningAgent {
+        chain_of_thought_extractor: anthropic_client
+            .extractor(anthropic::CLAUDE_3_5_SONNET)
+            .preamble(CHAIN_OF_THOUGHT_PROMPT)
+            .build(),
+        executor: anthropic_client
+            .agent(anthropic::CLAUDE_3_5_SONNET)
+            .preamble(
+                "You are an assistant here to help the user select which tool is most appropriate to perform arithmetic operations.
+                Follow these instructions closely.
+                1. Consider the user's request carefully and identify the core elements of the request.
+                2. Select which tool among those made available to you is appropriate given the context.
+                3. This is very important: never perform the operation yourself.
+                4. When you think you've finished calling tools for the operation, present the final result from the series of tool calls you made.
+                "
+            )
+            .tool(Add)
+            .tool(Subtract)
+            .tool(Multiply)
+            .tool(Divide)
+            .build(),
+    };
+
+    // Prompt the agent and print the response
+    let result = agent
+        .prompt("Calculate ((15 + 25) * (100 - 50)) / (200 / (10 + 10))")
+        .await?;
+
+    println!("\n\nReasoning Agent Chat History: {}", result);
+
+    Ok(())
+}
+
+#[derive(Deserialize)]
+struct OperationArgs {
+    x: i32,
+    y: i32,
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error("Math error")]
+struct MathError;
+
+#[derive(Deserialize, Serialize)]
+struct Add;
+impl Tool for Add {
+    const NAME: &'static str = "add";
+
+    type Error = MathError;
+    type Args = OperationArgs;
+    type Output = i32;
+
+    async fn definition(&self, _prompt: String) -> ToolDefinition {
+        serde_json::from_value(json!({
+            "name": "add",
+            "description": "Add x and y together",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "x": {
+                        "type": "number",
+                        "description": "The first number to add"
+                    },
+                    "y": {
+                        "type": "number",
+                        "description": "The second number to add"
+                    }
+                }
+            }
+        }))
+        .expect("Tool Definition")
+    }
+
+    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
+        let result = args.x + args.y;
+        Ok(result)
+    }
+}
+
+#[derive(Deserialize, Serialize)]
+struct Subtract;
+impl Tool for Subtract {
+    const NAME: &'static str = "subtract";
+
+    type Error = MathError;
+    type Args = OperationArgs;
+    type Output = i32;
+
+    async fn definition(&self, _prompt: String) -> ToolDefinition {
+        serde_json::from_value(json!({
+            "name": "subtract",
+            "description": "Subtract y from x (i.e.: x - y)",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "x": {
+                        "type": "number",
+                        "description": "The number to subtract from"
+                    },
+                    "y": {
+                        "type": "number",
+                        "description": "The number to subtract"
+                    }
+                }
+            }
+        }))
+        .expect("Tool Definition")
+    }
+
+    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
+        let result = args.x - args.y;
+        Ok(result)
+    }
+}
+
+struct Multiply;
+impl Tool for Multiply {
+    const NAME: &'static str = "multiply";
+
+    type Error = MathError;
+    type Args = OperationArgs;
+    type Output = i32;
+
+    async fn definition(&self, _prompt: String) -> ToolDefinition {
+        serde_json::from_value(json!({
+            "name": "multiply",
+            "description": "Compute the product of x and y (i.e.: x * y)",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "x": {
+                        "type": "number",
+                        "description": "The first factor in the product"
+                    },
+                    "y": {
+                        "type": "number",
+                        "description": "The second factor in the product"
+                    }
+                }
+            }
+        }))
+        .expect("Tool Definition")
+    }
+
+    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
+        let result = args.x * args.y;
+        Ok(result)
+    }
+}
+
+struct Divide;
+impl Tool for Divide {
+    const NAME: &'static str = "divide";
+
+    type Error = MathError;
+    type Args = OperationArgs;
+    type Output = i32;
+
+    async fn definition(&self, _prompt: String) -> ToolDefinition {
+        serde_json::from_value(json!({
+            "name": "divide",
+            "description": "Compute the Quotient of x and y (i.e.: x / y). Useful for ratios.",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "x": {
+                        "type": "number",
+                        "description": "The Dividend of the division. The number being divided"
+                    },
+                    "y": {
+                        "type": "number",
+                        "description": "The Divisor of the division. The number by which the dividend is being divided"
+                    }
+                }
+            }
+        }))
+        .expect("Tool Definition")
+    }
+
+    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
+        let result = args.x / args.y;
+        Ok(result)
+    }
+}
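The reasoning agent above composes two rig primitives: an `Extractor` that forces the model to return typed data, and an executor `Agent` that carries the tools. A minimal sketch of the extractor half in isolation, reusing the `ChainOfThoughtSteps` schema from the file (the turbofish is an alternative to letting the type be inferred from the struct field, and is an assumption about the builder's generics):

    let extractor = anthropic_client
        .extractor::<ChainOfThoughtSteps>(anthropic::CLAUDE_3_5_SONNET)
        .preamble(CHAIN_OF_THOUGHT_PROMPT)
        .build();

    // Yields a typed ChainOfThoughtSteps; the agent above maps extraction
    // failures to CompletionError::ProviderError.
    let steps = extractor.extract("Plan how to compute (2 + 3) * 4").await?;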
@@ -1,520 +0,0 @@
-//! This module contains the implementation of the [Agent] struct and its builder.
-//!
-//! The [Agent] struct represents an LLM agent, which combines an LLM model with a preamble (system prompt),
-//! a set of context documents, and a set of tools. Note: both context documents and tools can be either
-//! static (i.e.: they are always provided) or dynamic (i.e.: they are RAGged at prompt-time).
-//!
-//! The [Agent] struct is highly configurable, allowing the user to define anything from
-//! a simple bot with a specific system prompt to a complex RAG system with a set of dynamic
-//! context documents and tools.
-//!
-//! The [Agent] struct implements the [Completion] and [Prompt] traits, allowing it to be used for generating
-//! completions responses and prompts. The [Agent] struct also implements the [Chat] trait, which allows it to
-//! be used for generating chat completions.
-//!
-//! The [AgentBuilder] implements the builder pattern for creating instances of [Agent].
-//! It allows configuring the model, preamble, context documents, tools, temperature, and additional parameters
-//! before building the agent.
-//!
-//! # Example
-//! ```rust
-//! use rig::{
-//!     completion::{Chat, Completion, Prompt},
-//!     providers::openai,
-//! };
-//!
-//! let openai = openai::Client::from_env();
-//!
-//! // Configure the agent
-//! let agent = openai.agent("gpt-4o")
-//!     .preamble("System prompt")
-//!     .context("Context document 1")
-//!     .context("Context document 2")
-//!     .tool(tool1)
-//!     .tool(tool2)
-//!     .temperature(0.8)
-//!     .additional_params(json!({"foo": "bar"}))
-//!     .build();
-//!
-//! // Use the agent for completions and prompts
-//! // Generate a chat completion response from a prompt and chat history
-//! let chat_response = agent.chat("Prompt", chat_history)
-//!     .await
-//!     .expect("Failed to chat with Agent");
-//!
-//! // Generate a prompt completion response from a simple prompt
-//! let chat_response = agent.prompt("Prompt")
-//!     .await
-//!     .expect("Failed to prompt the Agent");
-//!
-//! // Generate a completion request builder from a prompt and chat history. The builder
-//! // will contain the agent's configuration (i.e.: preamble, context documents, tools,
-//! // model parameters, etc.), but these can be overwritten.
-//! let completion_req_builder = agent.completion("Prompt", chat_history)
-//!     .await
-//!     .expect("Failed to create completion request builder");
-//!
-//! let response = completion_req_builder
-//!     .temperature(0.9) // Overwrite the agent's temperature
-//!     .send()
-//!     .await
-//!     .expect("Failed to send completion request");
-//! ```
-//!
-//! RAG Agent example
-//! ```rust
-//! use rig::{
-//!     completion::Prompt,
-//!     embeddings::EmbeddingsBuilder,
-//!     providers::openai,
-//!     vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
-//! };
-//!
-//! // Initialize OpenAI client
-//! let openai = openai::Client::from_env();
-//!
-//! // Initialize OpenAI embedding model
-//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002);
-//!
-//! // Create vector store, compute embeddings and load them in the store
-//! let mut vector_store = InMemoryVectorStore::default();
-//!
-//! let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
-//!     .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
-//!     .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
-//!     .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
-//!     .build()
-//!     .await
-//!     .expect("Failed to build embeddings");
-//!
-//! vector_store.add_documents(embeddings)
-//!     .await
-//!     .expect("Failed to add documents");
-//!
-//! // Create vector store index
-//! let index = vector_store.index(embedding_model);
-//!
-//! let agent = openai.agent(openai::GPT_4O)
-//!     .preamble("
-//!         You are a dictionary assistant here to assist the user in understanding the meaning of words.
-//!         You will find additional non-standard word definitions that could be useful below.
-//!     ")
-//!     .dynamic_context(1, index)
-//!     .build();
-//!
-//! // Prompt the agent and print the response
-//! let response = agent.prompt("What does \"glarb-glarb\" mean?").await
-//!     .expect("Failed to prompt the agent");
-//! ```
-use std::collections::HashMap;
-
-use futures::{stream, StreamExt, TryStreamExt};
-
-use crate::{
-    completion::{
-        Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
-        Message, Prompt, PromptError,
-    },
-    message::AssistantContent,
-    streaming::{
-        StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
-        StreamingResult,
-    },
-    tool::{Tool, ToolSet},
-    vector_store::{VectorStoreError, VectorStoreIndexDyn},
-};
-
-#[cfg(feature = "mcp")]
-use crate::tool::McpTool;
-
-/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
-/// (i.e.: system prompt) and a static set of context documents and tools.
-/// All context documents and tools are always provided to the agent when prompted.
-///
-/// # Example
-/// ```
-/// use rig::{completion::Prompt, providers::openai};
-///
-/// let openai = openai::Client::from_env();
-///
-/// let comedian_agent = openai
-///     .agent("gpt-4o")
-///     .preamble("You are a comedian here to entertain the user using humour and jokes.")
-///     .temperature(0.9)
-///     .build();
-///
-/// let response = comedian_agent.prompt("Entertain me!")
-///     .await
-///     .expect("Failed to prompt the agent");
-/// ```
-pub struct Agent<M: CompletionModel> {
-    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
-    model: M,
-    /// System prompt
-    preamble: String,
-    /// Context documents always available to the agent
-    static_context: Vec<Document>,
-    /// Tools that are always available to the agent (identified by their name)
-    static_tools: Vec<String>,
-    /// Temperature of the model
-    temperature: Option<f64>,
-    /// Maximum number of tokens for the completion
-    max_tokens: Option<u64>,
-    /// Additional parameters to be passed to the model
-    additional_params: Option<serde_json::Value>,
-    /// List of vector store, with the sample number
-    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
-    /// Dynamic tools
-    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
-    /// Actual tool implementations
-    pub tools: ToolSet,
-}
-
-impl<M: CompletionModel> Completion<M> for Agent<M> {
-    async fn completion(
-        &self,
-        prompt: impl Into<Message> + Send,
-        chat_history: Vec<Message>,
-    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
-        let prompt = prompt.into();
-        let rag_text = prompt.rag_text().clone();
-
-        let completion_request = self
-            .model
-            .completion_request(prompt)
-            .preamble(self.preamble.clone())
-            .messages(chat_history)
-            .temperature_opt(self.temperature)
-            .max_tokens_opt(self.max_tokens)
-            .additional_params_opt(self.additional_params.clone())
-            .documents(self.static_context.clone());
-
-        let agent = match &rag_text {
-            Some(text) => {
-                let dynamic_context = stream::iter(self.dynamic_context.iter())
-                    .then(|(num_sample, index)| async {
-                        Ok::<_, VectorStoreError>(
-                            index
-                                .top_n(text, *num_sample)
-                                .await?
-                                .into_iter()
-                                .map(|(_, id, doc)| {
-                                    // Pretty print the document if possible for better readability
-                                    let text = serde_json::to_string_pretty(&doc)
-                                        .unwrap_or_else(|_| doc.to_string());
-
-                                    Document {
-                                        id,
-                                        text,
-                                        additional_props: HashMap::new(),
-                                    }
-                                })
-                                .collect::<Vec<_>>(),
-                        )
-                    })
-                    .try_fold(vec![], |mut acc, docs| async {
-                        acc.extend(docs);
-                        Ok(acc)
-                    })
-                    .await
-                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
-
-                let dynamic_tools = stream::iter(self.dynamic_tools.iter())
-                    .then(|(num_sample, index)| async {
-                        Ok::<_, VectorStoreError>(
-                            index
-                                .top_n_ids(text, *num_sample)
-                                .await?
-                                .into_iter()
-                                .map(|(_, id)| id)
-                                .collect::<Vec<_>>(),
-                        )
-                    })
-                    .try_fold(vec![], |mut acc, docs| async {
-                        for doc in docs {
-                            if let Some(tool) = self.tools.get(&doc) {
-                                acc.push(tool.definition(text.into()).await)
-                            } else {
-                                tracing::warn!("Tool implementation not found in toolset: {}", doc);
-                            }
-                        }
-                        Ok(acc)
-                    })
-                    .await
-                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
-
-                let static_tools = stream::iter(self.static_tools.iter())
-                    .filter_map(|toolname| async move {
-                        if let Some(tool) = self.tools.get(toolname) {
-                            Some(tool.definition(text.into()).await)
-                        } else {
-                            tracing::warn!(
-                                "Tool implementation not found in toolset: {}",
-                                toolname
-                            );
-                            None
-                        }
-                    })
-                    .collect::<Vec<_>>()
-                    .await;
-
-                completion_request
-                    .documents(dynamic_context)
-                    .tools([static_tools.clone(), dynamic_tools].concat())
-            }
-            None => {
-                let static_tools = stream::iter(self.static_tools.iter())
-                    .filter_map(|toolname| async move {
-                        if let Some(tool) = self.tools.get(toolname) {
-                            // TODO: tool definitions should likely take an `Option<String>`
-                            Some(tool.definition("".into()).await)
-                        } else {
-                            tracing::warn!(
-                                "Tool implementation not found in toolset: {}",
-                                toolname
-                            );
-                            None
-                        }
-                    })
-                    .collect::<Vec<_>>()
-                    .await;
-
-                completion_request.tools(static_tools)
-            }
-        };
-
-        Ok(agent)
-    }
-}
-
-impl<M: CompletionModel> Prompt for Agent<M> {
-    async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
-        self.chat(prompt, vec![]).await
-    }
-}
-
-impl<M: CompletionModel> Prompt for &Agent<M> {
-    async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
-        self.chat(prompt, vec![]).await
-    }
-}
-
-impl<M: CompletionModel> Chat for Agent<M> {
-    async fn chat(
-        &self,
-        prompt: impl Into<Message> + Send,
-        chat_history: Vec<Message>,
-    ) -> Result<String, PromptError> {
-        let resp = self.completion(prompt, chat_history).await?.send().await?;
-
-        // TODO: consider returning a `Message` instead of `String` for parallel responses / tool calls
-        match resp.choice.first() {
-            AssistantContent::Text(text) => Ok(text.text.clone()),
-            AssistantContent::ToolCall(tool_call) => Ok(self
-                .tools
-                .call(
-                    &tool_call.function.name,
-                    tool_call.function.arguments.to_string(),
-                )
-                .await?),
-        }
-    }
-}
-
-/// A builder for creating an agent
-///
-/// # Example
-/// ```
-/// use rig::{providers::openai, agent::AgentBuilder};
-///
-/// let openai = openai::Client::from_env();
-///
-/// let gpt4o = openai.completion_model("gpt-4o");
-///
-/// // Configure the agent
-/// let agent = AgentBuilder::new(model)
-///     .preamble("System prompt")
-///     .context("Context document 1")
-///     .context("Context document 2")
-///     .tool(tool1)
-///     .tool(tool2)
-///     .temperature(0.8)
-///     .additional_params(json!({"foo": "bar"}))
-///     .build();
-/// ```
-pub struct AgentBuilder<M: CompletionModel> {
-    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
-    model: M,
-    /// System prompt
-    preamble: Option<String>,
-    /// Context documents always available to the agent
-    static_context: Vec<Document>,
-    /// Tools that are always available to the agent (by name)
-    static_tools: Vec<String>,
-    /// Additional parameters to be passed to the model
-    additional_params: Option<serde_json::Value>,
-    /// Maximum number of tokens for the completion
-    max_tokens: Option<u64>,
-    /// List of vector store, with the sample number
-    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
-    /// Dynamic tools
-    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
-    /// Temperature of the model
-    temperature: Option<f64>,
-    /// Actual tool implementations
-    tools: ToolSet,
-}
-
-impl<M: CompletionModel> AgentBuilder<M> {
-    pub fn new(model: M) -> Self {
-        Self {
-            model,
-            preamble: None,
-            static_context: vec![],
-            static_tools: vec![],
-            temperature: None,
-            max_tokens: None,
-            additional_params: None,
-            dynamic_context: vec![],
-            dynamic_tools: vec![],
-            tools: ToolSet::default(),
-        }
-    }
-
-    /// Set the system prompt
-    pub fn preamble(mut self, preamble: &str) -> Self {
-        self.preamble = Some(preamble.into());
-        self
-    }
-
-    /// Append to the preamble of the agent
-    pub fn append_preamble(mut self, doc: &str) -> Self {
-        self.preamble = Some(format!(
-            "{}\n{}",
-            self.preamble.unwrap_or_else(|| "".into()),
-            doc
-        ));
-        self
-    }
-
-    /// Add a static context document to the agent
-    pub fn context(mut self, doc: &str) -> Self {
-        self.static_context.push(Document {
-            id: format!("static_doc_{}", self.static_context.len()),
-            text: doc.into(),
-            additional_props: HashMap::new(),
-        });
-        self
-    }
-
-    /// Add a static tool to the agent
-    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
-        let toolname = tool.name();
-        self.tools.add_tool(tool);
-        self.static_tools.push(toolname);
-        self
-    }
-
-    // Add an MCP tool to the agent
-    #[cfg(feature = "mcp")]
-    pub fn mcp_tool<T: mcp_core::transport::Transport>(
-        mut self,
-        tool: mcp_core::types::Tool,
-        client: mcp_core::client::Client<T>,
-    ) -> Self {
-        let toolname = tool.name.clone();
-        self.tools.add_tool(McpTool::from_mcp_server(tool, client));
-        self.static_tools.push(toolname);
-        self
-    }
-
-    /// Add some dynamic context to the agent. On each prompt, `sample` documents from the
-    /// dynamic context will be inserted in the request.
-    pub fn dynamic_context(
-        mut self,
-        sample: usize,
-        dynamic_context: impl VectorStoreIndexDyn + 'static,
-    ) -> Self {
-        self.dynamic_context
-            .push((sample, Box::new(dynamic_context)));
-        self
-    }
-
-    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
-    /// dynamic toolset will be inserted in the request.
-    pub fn dynamic_tools(
-        mut self,
-        sample: usize,
-        dynamic_tools: impl VectorStoreIndexDyn + 'static,
-        toolset: ToolSet,
-    ) -> Self {
-        self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
-        self.tools.add_tools(toolset);
-        self
-    }
-
-    /// Set the temperature of the model
-    pub fn temperature(mut self, temperature: f64) -> Self {
-        self.temperature = Some(temperature);
-        self
-    }
-
-    /// Set the maximum number of tokens for the completion
-    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
-        self.max_tokens = Some(max_tokens);
-        self
-    }
-
-    /// Set additional parameters to be passed to the model
-    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
-        self.additional_params = Some(params);
-        self
-    }
-
-    /// Build the agent
-    pub fn build(self) -> Agent<M> {
-        Agent {
-            model: self.model,
-            preamble: self.preamble.unwrap_or_default(),
-            static_context: self.static_context,
-            static_tools: self.static_tools,
-            temperature: self.temperature,
-            max_tokens: self.max_tokens,
-            additional_params: self.additional_params,
-            dynamic_context: self.dynamic_context,
-            dynamic_tools: self.dynamic_tools,
-            tools: self.tools,
-        }
-    }
-}
-
-impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
-    async fn stream_completion(
-        &self,
-        prompt: impl Into<Message> + Send,
-        chat_history: Vec<Message>,
-    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
-        // Reuse the existing completion implementation to build the request
-        // This ensures streaming and non-streaming use the same request building logic
-        self.completion(prompt, chat_history).await
-    }
-}
-
-impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
-    async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
-        self.stream_chat(prompt, vec![]).await
-    }
-}
-
-impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
-    async fn stream_chat(
-        &self,
-        prompt: &str,
-        chat_history: Vec<Message>,
-    ) -> Result<StreamingResult, CompletionError> {
-        self.stream_completion(prompt, chat_history)
-            .await?
-            .stream()
-            .await
-    }
-}
@@ -0,0 +1,179 @@
+use std::collections::HashMap;
+
+use crate::{
+    completion::{CompletionModel, Document},
+    tool::{Tool, ToolSet},
+    vector_store::VectorStoreIndexDyn,
+};
+
+#[cfg(feature = "mcp")]
+use crate::tool::McpTool;
+
+use super::Agent;
+
+/// A builder for creating an agent
+///
+/// # Example
+/// ```
+/// use rig::{providers::openai, agent::AgentBuilder};
+///
+/// let openai = openai::Client::from_env();
+///
+/// let gpt4o = openai.completion_model("gpt-4o");
+///
+/// // Configure the agent
+/// let agent = AgentBuilder::new(model)
+///     .preamble("System prompt")
+///     .context("Context document 1")
+///     .context("Context document 2")
+///     .tool(tool1)
+///     .tool(tool2)
+///     .temperature(0.8)
+///     .additional_params(json!({"foo": "bar"}))
+///     .build();
+/// ```
+pub struct AgentBuilder<M: CompletionModel> {
+    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
+    model: M,
+    /// System prompt
+    preamble: Option<String>,
+    /// Context documents always available to the agent
+    static_context: Vec<Document>,
+    /// Tools that are always available to the agent (by name)
+    static_tools: Vec<String>,
+    /// Additional parameters to be passed to the model
+    additional_params: Option<serde_json::Value>,
+    /// Maximum number of tokens for the completion
+    max_tokens: Option<u64>,
+    /// List of vector store, with the sample number
+    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
+    /// Dynamic tools
+    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
+    /// Temperature of the model
+    temperature: Option<f64>,
+    /// Actual tool implementations
+    tools: ToolSet,
+}
+
+impl<M: CompletionModel> AgentBuilder<M> {
+    pub fn new(model: M) -> Self {
+        Self {
+            model,
+            preamble: None,
+            static_context: vec![],
+            static_tools: vec![],
+            temperature: None,
+            max_tokens: None,
+            additional_params: None,
+            dynamic_context: vec![],
+            dynamic_tools: vec![],
+            tools: ToolSet::default(),
+        }
+    }
+
+    /// Set the system prompt
+    pub fn preamble(mut self, preamble: &str) -> Self {
+        self.preamble = Some(preamble.into());
+        self
+    }
+
+    /// Append to the preamble of the agent
+    pub fn append_preamble(mut self, doc: &str) -> Self {
+        self.preamble = Some(format!(
+            "{}\n{}",
+            self.preamble.unwrap_or_else(|| "".into()),
+            doc
+        ));
+        self
+    }
+
+    /// Add a static context document to the agent
+    pub fn context(mut self, doc: &str) -> Self {
+        self.static_context.push(Document {
+            id: format!("static_doc_{}", self.static_context.len()),
+            text: doc.into(),
+            additional_props: HashMap::new(),
+        });
+        self
+    }
+
+    /// Add a static tool to the agent
+    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
+        let toolname = tool.name();
+        self.tools.add_tool(tool);
+        self.static_tools.push(toolname);
+        self
+    }
+
+    // Add an MCP tool to the agent
+    #[cfg(feature = "mcp")]
+    pub fn mcp_tool<T: mcp_core::transport::Transport>(
+        mut self,
+        tool: mcp_core::types::Tool,
+        client: mcp_core::client::Client<T>,
+    ) -> Self {
+        let toolname = tool.name.clone();
+        self.tools.add_tool(McpTool::from_mcp_server(tool, client));
+        self.static_tools.push(toolname);
+        self
+    }
+
+    /// Add some dynamic context to the agent. On each prompt, `sample` documents from the
+    /// dynamic context will be inserted in the request.
+    pub fn dynamic_context(
+        mut self,
+        sample: usize,
+        dynamic_context: impl VectorStoreIndexDyn + 'static,
+    ) -> Self {
+        self.dynamic_context
+            .push((sample, Box::new(dynamic_context)));
+        self
+    }
+
+    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
+    /// dynamic toolset will be inserted in the request.
+    pub fn dynamic_tools(
+        mut self,
+        sample: usize,
+        dynamic_tools: impl VectorStoreIndexDyn + 'static,
+        toolset: ToolSet,
+    ) -> Self {
+        self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
+        self.tools.add_tools(toolset);
+        self
+    }
+
+    /// Set the temperature of the model
+    pub fn temperature(mut self, temperature: f64) -> Self {
+        self.temperature = Some(temperature);
+        self
+    }
+
+    /// Set the maximum number of tokens for the completion
+    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
+        self.max_tokens = Some(max_tokens);
+        self
+    }
+
+    /// Set additional parameters to be passed to the model
+    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
+        self.additional_params = Some(params);
+        self
+    }
+
+    /// Build the agent
+    pub fn build(self) -> Agent<M> {
+        Agent {
+            model: self.model,
+            preamble: self.preamble.unwrap_or_default(),
+            static_context: self.static_context,
+            static_tools: self.static_tools,
+            temperature: self.temperature,
+            max_tokens: self.max_tokens,
+            additional_params: self.additional_params,
+            dynamic_context: self.dynamic_context,
+            dynamic_tools: self.dynamic_tools,
+            tools: self.tools,
+        }
+    }
+}
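The builder extracted into this new module carries essentially the same API that previously lived in the deleted agent.rs, so downstream construction code should be unaffected by the split. A quick sketch of direct use (assumes `model` is any `CompletionModel` implementation obtained from a provider client):

    let agent = AgentBuilder::new(model)
        .preamble("You are a helpful assistant.")
        .temperature(0.8)
        .max_tokens(1024)
        .build();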
@ -0,0 +1,244 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use futures::{stream, StreamExt, TryStreamExt};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
completion::{
|
||||||
|
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
|
||||||
|
Message, Prompt, PromptError,
|
||||||
|
},
|
||||||
|
streaming::{
|
||||||
|
StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
|
||||||
|
StreamingResult,
|
||||||
|
},
|
||||||
|
tool::ToolSet,
|
||||||
|
vector_store::VectorStoreError,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::prompt_request::PromptRequest;
|
||||||
|
|
||||||
|
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
|
||||||
|
/// (i.e.: system prompt) and a static set of context documents and tools.
|
||||||
|
/// All context documents and tools are always provided to the agent when prompted.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
/// ```
|
||||||
|
/// use rig::{completion::Prompt, providers::openai};
|
||||||
|
///
|
||||||
|
/// let openai = openai::Client::from_env();
|
||||||
|
///
|
||||||
|
/// let comedian_agent = openai
|
||||||
|
/// .agent("gpt-4o")
|
||||||
|
/// .preamble("You are a comedian here to entertain the user using humour and jokes.")
|
||||||
|
/// .temperature(0.9)
|
||||||
|
/// .build();
|
||||||
|
///
|
||||||
|
/// let response = comedian_agent.prompt("Entertain me!")
|
||||||
|
/// .await
|
||||||
|
/// .expect("Failed to prompt the agent");
|
||||||
|
/// ```
|
||||||
|
pub struct Agent<M: CompletionModel> {
|
||||||
|
/// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
|
||||||
|
pub model: M,
|
||||||
|
/// System prompt
|
||||||
|
pub preamble: String,
|
||||||
|
/// Context documents always available to the agent
|
||||||
|
pub static_context: Vec<Document>,
|
||||||
|
/// Tools that are always available to the agent (identified by their name)
|
||||||
|
pub static_tools: Vec<String>,
|
||||||
|
/// Temperature of the model
|
||||||
|
pub temperature: Option<f64>,
|
||||||
|
/// Maximum number of tokens for the completion
|
||||||
|
pub max_tokens: Option<u64>,
|
||||||
|
/// Additional parameters to be passed to the model
|
||||||
|
pub additional_params: Option<serde_json::Value>,
|
||||||
|
/// List of vector store, with the sample number
|
||||||
|
pub dynamic_context: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
|
||||||
|
/// Dynamic tools
|
||||||
|
pub dynamic_tools: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
|
||||||
|
/// Actual tool implementations
|
||||||
|
pub tools: ToolSet,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<M: CompletionModel> Completion<M> for Agent<M> {
|
||||||
|
async fn completion(
|
||||||
|
&self,
|
||||||
|
prompt: impl Into<Message> + Send,
|
||||||
|
chat_history: Vec<Message>,
|
||||||
|
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
|
||||||
|
let prompt = prompt.into();
|
||||||
|
let rag_text = prompt.rag_text().clone();
|
||||||
|
|
||||||
|
let completion_request = self
|
||||||
|
.model
|
||||||
|
.completion_request(prompt)
|
||||||
|
.preamble(self.preamble.clone())
|
||||||
|
.messages(chat_history)
|
||||||
|
.temperature_opt(self.temperature)
|
||||||
|
.max_tokens_opt(self.max_tokens)
|
||||||
|
.additional_params_opt(self.additional_params.clone())
|
||||||
|
.documents(self.static_context.clone());
|
||||||
|
|
||||||
|
let agent = match &rag_text {
|
||||||
|
Some(text) => {
|
||||||
|
let dynamic_context = stream::iter(self.dynamic_context.iter())
|
||||||
|
.then(|(num_sample, index)| async {
|
||||||
|
Ok::<_, VectorStoreError>(
|
||||||
|
index
|
||||||
|
.top_n(text, *num_sample)
|
||||||
|
.await?
|
||||||
|
.into_iter()
|
||||||
|
.map(|(_, id, doc)| {
|
||||||
|
// Pretty print the document if possible for better readability
|
||||||
|
let text = serde_json::to_string_pretty(&doc)
|
||||||
|
.unwrap_or_else(|_| doc.to_string());
|
||||||
|
|
||||||
|
Document {
|
||||||
|
id,
|
||||||
|
text,
|
||||||
|
additional_props: HashMap::new(),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.try_fold(vec![], |mut acc, docs| async {
|
||||||
|
acc.extend(docs);
|
||||||
|
Ok(acc)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;
|
||||||
|
|
||||||
|
let dynamic_tools = stream::iter(self.dynamic_tools.iter())
|
||||||
|
.then(|(num_sample, index)| async {
|
||||||
|
Ok::<_, VectorStoreError>(
|
||||||
|
index
|
||||||
|
.top_n_ids(text, *num_sample)
|
||||||
|
.await?
|
||||||
|
.into_iter()
|
||||||
|
.map(|(_, id)| id)
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.try_fold(vec![], |mut acc, docs| async {
|
||||||
|
for doc in docs {
|
||||||
|
if let Some(tool) = self.tools.get(&doc) {
|
||||||
|
acc.push(tool.definition(text.into()).await)
|
||||||
|
} else {
|
||||||
|
tracing::warn!("Tool implementation not found in toolset: {}", doc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(acc)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;
|
||||||
|
|
||||||
|
let static_tools = stream::iter(self.static_tools.iter())
|
||||||
|
.filter_map(|toolname| async move {
|
||||||
                        if let Some(tool) = self.tools.get(toolname) {
                            Some(tool.definition(text.into()).await)
                        } else {
                            tracing::warn!(
                                "Tool implementation not found in toolset: {}",
                                toolname
                            );
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .await;

                completion_request
                    .documents(dynamic_context)
                    .tools([static_tools.clone(), dynamic_tools].concat())
            }
            None => {
                let static_tools = stream::iter(self.static_tools.iter())
                    .filter_map(|toolname| async move {
                        if let Some(tool) = self.tools.get(toolname) {
                            // TODO: tool definitions should likely take an `Option<String>`
                            Some(tool.definition("".into()).await)
                        } else {
                            tracing::warn!(
                                "Tool implementation not found in toolset: {}",
                                toolname
                            );
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .await;

                completion_request.tools(static_tools)
            }
        };

        Ok(agent)
    }
}

// Here, we need to ensure that usage of `.prompt` on agent uses these redefinitions on the opaque
// `Prompt` trait so that when `.prompt` is used at the call-site, it'll use the more specific
// `PromptRequest` implementation for `Agent`, making the builder's usage fluent.
//
// References:
// - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait)

#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for Agent<M> {
    fn prompt(&self, prompt: impl Into<Message> + Send) -> PromptRequest<M> {
        PromptRequest::new(self, prompt)
    }
}

#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for &Agent<M> {
    fn prompt(&self, prompt: impl Into<Message> + Send) -> PromptRequest<M> {
        PromptRequest::new(*self, prompt)
    }
}

#[allow(refining_impl_trait)]
impl<M: CompletionModel> Chat for Agent<M> {
    async fn chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> Result<String, PromptError> {
        let mut cloned_history = chat_history.clone();
        PromptRequest::new(self, prompt)
            .with_history(&mut cloned_history)
            .await
    }
}

impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
    async fn stream_completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        // Reuse the existing completion implementation to build the request
        // This ensures streaming and non-streaming use the same request building logic
        self.completion(prompt, chat_history).await
    }
}

impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
    async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
        self.stream_chat(prompt, vec![]).await
    }
}

impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
    async fn stream_chat(
        &self,
        prompt: &str,
        chat_history: Vec<Message>,
    ) -> Result<StreamingResult, CompletionError> {
        self.stream_completion(prompt, chat_history)
            .await?
            .stream()
            .await
    }
}
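For orientation, here is a hedged call-site sketch of what these `Prompt` refinements buy: `prompt` now returns a concrete `PromptRequest`, which implements `IntoFuture`, so a bare `.await` keeps working while the request can first be extended fluently. The client construction and error handling below are assumptions, not part of this diff.

// Sketch only: assumes an OpenAI client as in the module docs below.
let agent = openai.agent("gpt-4o").preamble("You are terse.").build();

// A bare `.await` still works because `PromptRequest` implements `IntoFuture`.
let answer: String = agent.prompt("What is 2 + 2?").await?;

// The same call can be extended fluently before awaiting.
let mut history: Vec<Message> = vec![];
let answer = agent
    .prompt("Book a table and confirm the time")
    .with_history(&mut history) // history is updated in place across turns
    .multi_turn(3)              // allow up to 3 tool-call roundtrips
    .await?;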
@ -0,0 +1,116 @@
//! This module contains the implementation of the [Agent] struct and its builder.
//!
//! The [Agent] struct represents an LLM agent, which combines an LLM model with a preamble (system prompt),
//! a set of context documents, and a set of tools. Note: both context documents and tools can be either
//! static (i.e.: they are always provided) or dynamic (i.e.: they are RAGged at prompt-time).
//!
//! The [Agent] struct is highly configurable, allowing the user to define anything from
//! a simple bot with a specific system prompt to a complex RAG system with a set of dynamic
//! context documents and tools.
//!
//! The [Agent] struct implements the [crate::completion::Completion] and [crate::completion::Prompt] traits,
//! allowing it to be used for generating completion responses and prompts. The [Agent] struct also
//! implements the [crate::completion::Chat] trait, which allows it to be used for generating chat completions.
//!
//! The [AgentBuilder] implements the builder pattern for creating instances of [Agent].
//! It allows configuring the model, preamble, context documents, tools, temperature, and additional parameters
//! before building the agent.
//!
//! # Example
//! ```rust
//! use rig::{
//!     completion::{Chat, Completion, Prompt},
//!     providers::openai,
//! };
//!
//! let openai = openai::Client::from_env();
//!
//! // Configure the agent
//! let agent = openai.agent("gpt-4o")
//!     .preamble("System prompt")
//!     .context("Context document 1")
//!     .context("Context document 2")
//!     .tool(tool1)
//!     .tool(tool2)
//!     .temperature(0.8)
//!     .additional_params(json!({"foo": "bar"}))
//!     .build();
//!
//! // Use the agent for completions and prompts
//! // Generate a chat completion response from a prompt and chat history
//! let chat_response = agent.chat("Prompt", chat_history)
//!     .await
//!     .expect("Failed to chat with Agent");
//!
//! // Generate a prompt completion response from a simple prompt
//! let chat_response = agent.prompt("Prompt")
//!     .await
//!     .expect("Failed to prompt the Agent");
//!
//! // Generate a completion request builder from a prompt and chat history. The builder
//! // will contain the agent's configuration (i.e.: preamble, context documents, tools,
//! // model parameters, etc.), but these can be overwritten.
//! let completion_req_builder = agent.completion("Prompt", chat_history)
//!     .await
//!     .expect("Failed to create completion request builder");
//!
//! let response = completion_req_builder
//!     .temperature(0.9) // Overwrite the agent's temperature
//!     .send()
//!     .await
//!     .expect("Failed to send completion request");
//! ```
//!
//! RAG Agent example
//! ```rust
//! use rig::{
//!     completion::Prompt,
//!     embeddings::EmbeddingsBuilder,
//!     providers::openai,
//!     vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
//! };
//!
//! // Initialize OpenAI client
//! let openai = openai::Client::from_env();
//!
//! // Initialize OpenAI embedding model
//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002);
//!
//! // Create vector store, compute embeddings and load them in the store
//! let mut vector_store = InMemoryVectorStore::default();
//!
//! let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
//!     .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
//!     .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is an ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
//!     .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
//!     .build()
//!     .await
//!     .expect("Failed to build embeddings");
//!
//! vector_store.add_documents(embeddings)
//!     .await
//!     .expect("Failed to add documents");
//!
//! // Create vector store index
//! let index = vector_store.index(embedding_model);
//!
//! let agent = openai.agent(openai::GPT_4O)
//!     .preamble("
//!         You are a dictionary assistant here to assist the user in understanding the meaning of words.
//!         You will find additional non-standard word definitions that could be useful below.
//!     ")
//!     .dynamic_context(1, index)
//!     .build();
//!
//! // Prompt the agent and print the response
//! let response = agent.prompt("What does \"glarb-glarb\" mean?").await
//!     .expect("Failed to prompt the agent");
//! ```

mod builder;
mod completion;
mod prompt_request;

pub use builder::AgentBuilder;
pub use completion::Agent;
pub use prompt_request::PromptRequest;
@ -0,0 +1,173 @@
use std::future::IntoFuture;

use futures::{future::BoxFuture, stream, FutureExt, StreamExt};

use crate::{
    completion::{Completion, CompletionError, CompletionModel, Message, PromptError},
    message::{AssistantContent, UserContent},
    tool::ToolSetError,
    OneOrMany,
};

use super::Agent;

/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
pub struct PromptRequest<'a, M: CompletionModel> {
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt
    /// Note: chat history needs to outlive the agent as it might be used with other agents
    chat_history: Option<&'a mut Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_depth: usize,
    /// The agent to use for execution
    agent: &'a Agent<M>,
}

impl<'a, M: CompletionModel> PromptRequest<'a, M> {
    /// Create a new PromptRequest with the given prompt and model
    pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: 0,
            agent,
        }
    }
}

impl<'a, M: CompletionModel> PromptRequest<'a, M> {
    /// Set the maximum depth for multi-turn conversations
    pub fn multi_turn(self, depth: usize) -> PromptRequest<'a, M> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: depth,
            agent: self.agent,
        }
    }

    /// Add chat history to the prompt request
    pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, M> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: Some(history),
            max_depth: self.max_depth,
            agent: self.agent,
        }
    }
}

/// Due to [RFC 2515](https://github.com/rust-lang/rust/issues/63063), we have to use a `BoxFuture`
/// for the `IntoFuture` implementation. In the future, we should be able to use `impl Future<...>`
/// directly via the associated type.
impl<'a, M: CompletionModel> IntoFuture for PromptRequest<'a, M> {
    type Output = Result<String, PromptError>;
    type IntoFuture = BoxFuture<'a, Self::Output>; // This future should not outlive the agent

    fn into_future(self) -> Self::IntoFuture {
        self.send().boxed()
    }
}

impl<M: CompletionModel> PromptRequest<'_, M> {
    async fn send(self) -> Result<String, PromptError> {
        let agent = self.agent;
        let mut prompt = self.prompt;
        let chat_history = if let Some(history) = self.chat_history {
            history
        } else {
            &mut Vec::new()
        };

        let mut current_max_depth = 0;
        // We need to do at least 2 loops for 1 roundtrip (user expects normal message)
        while current_max_depth <= self.max_depth + 1 {
            current_max_depth += 1;

            if self.max_depth > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_depth,
                    self.max_depth
                );
            }

            let resp = agent
                .completion(prompt.clone(), chat_history.to_vec())
                .await?
                .send()
                .await?;

            chat_history.push(prompt);

            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            chat_history.push(Message::Assistant {
                content: resp.choice.clone(),
            });

            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_depth > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
                }

                // If there are no tool calls, depth is not relevant; we can just return the merged text.
                return Ok(merged_texts);
            }

            let tool_content = stream::iter(tool_calls)
                .then(async |choice| {
                    if let AssistantContent::ToolCall(tool_call) = choice {
                        let output = agent
                            .tools
                            .call(
                                &tool_call.function.name,
                                tool_call.function.arguments.to_string(),
                            )
                            .await?;
                        Ok(UserContent::tool_result(
                            tool_call.id.clone(),
                            OneOrMany::one(output.into()),
                        ))
                    } else {
                        unreachable!(
                            "This should never happen as we already filtered for `ToolCall`"
                        )
                    }
                })
                .collect::<Vec<Result<UserContent, ToolSetError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

            prompt = Message::User {
                content: OneOrMany::many(tool_content).expect("There is at least one tool call"),
            };
        }

        // If we reach here, we never resolved the final tool call. We need to do ... something.
        Err(PromptError::MaxDepthError {
            max_depth: self.max_depth,
            chat_history: chat_history.clone(),
            prompt,
        })
    }
}
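The loop above caps tool-call roundtrips at `max_depth`; when the cap is hit, the new `MaxDepthError` variant hands the accumulated state back to the caller. A hedged sketch of recovering from it (the agent and the enclosing function's return type are assumptions, not part of this diff):

match agent.prompt("research this topic").multi_turn(2).await {
    Ok(text) => println!("{text}"),
    Err(PromptError::MaxDepthError { max_depth, chat_history, .. }) => {
        // The history and pending prompt are returned, so a caller can
        // inspect partial progress or retry with a larger depth.
        eprintln!(
            "stopped after {max_depth} turns ({} messages so far)",
            chat_history.len()
        );
    }
    Err(e) => return Err(e),
}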
@ -229,6 +229,16 @@ impl Message {
             content: OneOrMany::one(AssistantContent::text(text)),
         }
     }
+
+    // Helper constructor to make creating tool result messages easier.
+    pub fn tool_result(id: impl Into<String>, content: impl Into<String>) -> Self {
+        Message::User {
+            content: OneOrMany::one(UserContent::ToolResult(ToolResult {
+                id: id.into(),
+                content: OneOrMany::one(ToolResultContent::text(content)),
+            })),
+        }
+    }
 }

 impl UserContent {
@ -467,6 +477,12 @@ impl From<String> for Text {
     }
 }
+
+impl From<&String> for Text {
+    fn from(text: &String) -> Self {
+        text.to_owned().into()
+    }
+}

 impl From<&str> for Text {
     fn from(text: &str) -> Self {
         text.to_owned().into()
@ -497,6 +513,14 @@ impl From<&str> for Message {
     }
 }
+
+impl From<&String> for Message {
+    fn from(text: &String) -> Self {
+        Message::User {
+            content: OneOrMany::one(UserContent::Text(text.into())),
+        }
+    }
+}

 impl From<Text> for Message {
     fn from(text: Text) -> Self {
         Message::User {
@ -75,7 +75,7 @@ use crate::{
     tool::ToolSetError,
 };

-use super::message::AssistantContent;
+use super::message::{AssistantContent, ContentFormat, DocumentMediaType};

 // Errors
 #[derive(Debug, Error)]
@ -108,6 +108,13 @@ pub enum PromptError {
     #[error("ToolCallError: {0}")]
     ToolError(#[from] ToolSetError),
+
+    #[error("MaxDepthError: (reached limit: {max_depth})")]
+    MaxDepthError {
+        max_depth: usize,
+        chat_history: Vec<Message>,
+        prompt: Message,
+    },
 }

 #[derive(Clone, Debug, Deserialize, Serialize)]
@ -163,7 +170,7 @@ pub trait Prompt: Send + Sync {
     fn prompt(
         &self,
         prompt: impl Into<Message> + Send,
-    ) -> impl std::future::Future<Output = Result<String, PromptError>> + Send;
+    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: Send>;
 }

 /// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
@ -180,7 +187,7 @@ pub trait Chat: Send + Sync {
         &self,
         prompt: impl Into<Message> + Send,
         chat_history: Vec<Message>,
-    ) -> impl std::future::Future<Output = Result<String, PromptError>> + Send;
+    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: Send>;
 }

 /// Trait defining a low-level LLM completion interface
@ -236,12 +243,11 @@ pub trait CompletionModel: Clone + Send + Sync {
 /// Struct representing a general completion request that can be sent to a completion model provider.
 pub struct CompletionRequest {
-    /// The prompt to be sent to the completion model provider
-    pub prompt: Message,
     /// The preamble to be sent to the completion model provider
     pub preamble: Option<String>,
     /// The chat history to be sent to the completion model provider
-    pub chat_history: Vec<Message>,
+    /// The very last message will always be the prompt (hence why there is *always* one)
+    pub chat_history: OneOrMany<Message>,
     /// The documents to be sent to the completion model provider
     pub documents: Vec<Document>,
     /// The tools to be sent to the completion model provider
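Since `chat_history` is now a `OneOrMany`, the prompt no longer needs its own field: the history is guaranteed non-empty and its final element is always the prompt. A minimal sketch of the invariant, assuming `OneOrMany`'s `many`/`last` accessors behave as their names suggest:

// `OneOrMany::many` fails on an empty vec, so a built request always carries a prompt.
let history = OneOrMany::many(vec![earlier_msg, prompt_msg]).expect("non-empty");
let prompt = history.last(); // the prompt is always the final message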
@ -255,23 +261,33 @@ pub struct CompletionRequest {
 }

 impl CompletionRequest {
-    pub fn prompt_with_context(&self) -> Message {
-        let mut new_prompt = self.prompt.clone();
-        if let Message::User { ref mut content } = new_prompt {
-            if !self.documents.is_empty() {
-                let attachments = self
-                    .documents
-                    .iter()
-                    .map(|doc| doc.to_string())
-                    .collect::<Vec<_>>()
-                    .join("");
-                let formatted_content = format!("<attachments>\n{}</attachments>", attachments);
-                let mut new_content = vec![UserContent::text(formatted_content)];
-                new_content.extend(content.clone());
-                *content = OneOrMany::many(new_content).expect("This has more than 1 item");
-            }
-        }
-        new_prompt
+    /// Returns documents normalized into a message (if any).
+    /// Most providers do not accept documents directly as input, so it needs to convert into a
+    /// `Message` so that it can be incorporated into `chat_history` as a single user message.
+    pub fn normalized_documents(&self) -> Option<Message> {
+        if self.documents.is_empty() {
+            return None;
+        }
+
+        // Most providers will convert documents into a text unless it can handle document messages.
+        // We use `UserContent::document` for those who handle it directly!
+        let messages = self
+            .documents
+            .iter()
+            .map(|doc| {
+                UserContent::document(
+                    doc.to_string(),
+                    // In the future, we can customize `Document` to pass these extra types through.
+                    // Most providers ditch these but they might want to use them.
+                    Some(ContentFormat::String),
+                    Some(DocumentMediaType::TXT),
+                )
+            })
+            .collect::<Vec<_>>();
+
+        Some(Message::User {
+            content: OneOrMany::many(messages).expect("There will be at least one document"),
+        })
     }
 }
@ -446,10 +462,12 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
     /// Builds the completion request.
     pub fn build(self) -> CompletionRequest {
+        let chat_history = OneOrMany::many([self.chat_history, vec![self.prompt]].concat())
+            .expect("There will always be at least the prompt");
+
         CompletionRequest {
-            prompt: self.prompt,
             preamble: self.preamble,
-            chat_history: self.chat_history,
+            chat_history,
             documents: self.documents,
             tools: self.tools,
             temperature: self.temperature,
@ -475,7 +493,6 @@ impl<M: StreamingCompletionModel> CompletionRequestBuilder<M> {
 #[cfg(test)]
 mod tests {
-    use crate::OneOrMany;

     use super::*;
@ -513,7 +530,7 @@ mod tests {
     }

     #[test]
-    fn test_prompt_with_context_with_documents() {
+    fn test_normalize_documents_with_documents() {
         let doc1 = Document {
             id: "doc1".to_string(),
             text: "Document 1 text.".to_string(),
@ -527,9 +544,8 @@ mod tests {
         };

         let request = CompletionRequest {
-            prompt: "What is the capital of France?".into(),
             preamble: None,
-            chat_history: Vec::new(),
+            chat_history: OneOrMany::one("What is the capital of France?".into()),
             documents: vec![doc1, doc2],
             tools: Vec::new(),
             temperature: None,
@ -539,19 +555,35 @@ mod tests {
         let expected = Message::User {
             content: OneOrMany::many(vec![
-                UserContent::text(concat!(
-                    "<attachments>\n",
-                    "<file id: doc1>\nDocument 1 text.\n</file>\n",
-                    "<file id: doc2>\nDocument 2 text.\n</file>\n",
-                    "</attachments>"
-                )),
-                UserContent::text("What is the capital of France?"),
+                UserContent::document(
+                    "<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
+                    Some(ContentFormat::String),
+                    Some(DocumentMediaType::TXT),
+                ),
+                UserContent::document(
+                    "<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
+                    Some(ContentFormat::String),
+                    Some(DocumentMediaType::TXT),
+                ),
             ])
-            .expect("This has more than 1 item"),
+            .expect("There will be at least one document"),
         };

-        request.prompt_with_context();
-
-        assert_eq!(request.prompt_with_context(), expected);
+        assert_eq!(request.normalized_documents(), Some(expected));
+    }
+
+    #[test]
+    fn test_normalize_documents_without_documents() {
+        let request = CompletionRequest {
+            preamble: None,
+            chat_history: OneOrMany::one("What is the capital of France?".into()),
+            documents: Vec::new(),
+            tools: Vec::new(),
+            temperature: None,
+            max_tokens: None,
+            additional_params: None,
+        };
+
+        assert_eq!(request.normalized_documents(), None);
     }
 }
@ -36,10 +36,13 @@ use serde_json::json;
 use crate::{
     agent::{Agent, AgentBuilder},
-    completion::{CompletionModel, Prompt, PromptError, ToolDefinition},
+    completion::{Completion, CompletionError, CompletionModel, ToolDefinition},
+    message::{AssistantContent, Message, ToolCall, ToolFunction},
     tool::Tool,
 };

+const SUBMIT_TOOL_NAME: &str = "submit";
+
 #[derive(Debug, thiserror::Error)]
 pub enum ExtractionError {
     #[error("No data extracted")]

@ -48,8 +51,8 @@ pub enum ExtractionError {
     #[error("Failed to deserialize the extracted data: {0}")]
     DeserializationError(#[from] serde_json::Error),

-    #[error("PromptError: {0}")]
-    PromptError(#[from] PromptError),
+    #[error("CompletionError: {0}")]
+    CompletionError(#[from] CompletionError),
 }

 /// Extractor for structured data from text

@ -62,14 +65,43 @@ impl<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, M: CompletionModel>
 where
     M: Sync,
 {
-    pub async fn extract(&self, text: &str) -> Result<T, ExtractionError> {
-        let summary = self.agent.prompt(text).await?;
-
-        if summary.is_empty() {
-            return Err(ExtractionError::NoData);
+    pub async fn extract(&self, text: impl Into<Message> + Send) -> Result<T, ExtractionError> {
+        let response = self.agent.completion(text, vec![]).await?.send().await?;
+
+        let arguments = response
+            .choice
+            .into_iter()
+            // We filter tool calls to look for submit tool calls
+            .filter_map(|content| {
+                if let AssistantContent::ToolCall(ToolCall {
+                    function: ToolFunction { arguments, name },
+                    ..
+                }) = content
+                {
+                    if name == SUBMIT_TOOL_NAME {
+                        Some(arguments)
+                    } else {
+                        None
+                    }
+                } else {
+                    None
+                }
+            })
+            .collect::<Vec<_>>();
+
+        if arguments.len() > 1 {
+            tracing::warn!(
+                "Multiple submit calls detected; using the first one. Providers/agents should ensure only one submit call."
+            );
         }

-        Ok(serde_json::from_str(&summary)?)
+        let raw_data = if let Some(arg) = arguments.into_iter().next() {
+            arg
+        } else {
+            return Err(ExtractionError::NoData);
+        };
+
+        Ok(serde_json::from_value(raw_data)?)
     }
 }

@ -132,7 +164,7 @@ struct SubmitTool<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> {
 struct SubmitError;

 impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync> Tool for SubmitTool<T> {
-    const NAME: &'static str = "submit";
+    const NAME: &'static str = SUBMIT_TOOL_NAME;
     type Error = SubmitError;
     type Args = T;
     type Output = T;
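End to end, the extractor now reads structured arguments off the `submit` tool call instead of parsing the model's free-form text. A hedged usage sketch; the `extractor` builder on the provider client is assumed from the surrounding crate, not shown in this diff:

#[derive(serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
struct Person {
    name: String,
    age: u8,
}

// Assumed builder API on the provider client.
let extractor = openai.extractor::<Person>("gpt-4o").build();

// The model is expected to call the `submit` tool; its JSON arguments are
// deserialized with `serde_json::from_value`, so no free-text parsing.
let person: Person = extractor.extract("John is 25 years old").await?;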
@ -1,6 +1,9 @@
+use std::future::IntoFuture;
+
 use crate::{
     completion::{self, CompletionModel},
     extractor::{ExtractionError, Extractor},
+    message::Message,
     vector_store,
 };

@ -79,14 +82,14 @@ impl<P, In> Prompt<P, In> {
 impl<P, In> Op for Prompt<P, In>
 where
-    P: completion::Prompt,
+    P: completion::Prompt + Send + Sync,
     In: Into<String> + Send + Sync,
 {
     type Input = In;
     type Output = Result<String, completion::PromptError>;

-    async fn call(&self, input: Self::Input) -> Self::Output {
-        self.prompt.prompt(input.into()).await
+    fn call(&self, input: Self::Input) -> impl std::future::Future<Output = Self::Output> + Send {
+        self.prompt.prompt(input.into()).into_future()
     }
 }

@ -127,13 +130,13 @@ impl<M, Input, Output> Op for Extract<M, Input, Output>
 where
     M: CompletionModel,
     Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
-    Input: Into<String> + Send + Sync,
+    Input: Into<Message> + Send + Sync,
 {
     type Input = Input;
     type Output = Result<Output, ExtractionError>;

     async fn call(&self, input: Self::Input) -> Self::Output {
-        self.extractor.extract(&input.into()).await
+        self.extractor.extract(input).await
     }
 }

@ -159,6 +162,7 @@ pub mod tests {
     pub struct MockModel;

     impl Prompt for MockModel {
+        #[allow(refining_impl_trait)]
         async fn prompt(&self, prompt: impl Into<message::Message>) -> Result<String, PromptError> {
             let msg: message::Message = prompt.into();
             let prompt = match msg {
@ -553,26 +553,20 @@ impl completion::CompletionModel for CompletionModel {
             ));
         };

-        let prompt_message: Message = completion_request
-            .prompt_with_context()
-            .try_into()
-            .map_err(|e: MessageError| CompletionError::RequestError(e.into()))?;
-
-        let mut messages = completion_request
-            .chat_history
-            .into_iter()
-            .map(|message| {
-                message
-                    .try_into()
-                    .map_err(|e: MessageError| CompletionError::RequestError(e.into()))
-            })
-            .collect::<Result<Vec<Message>, _>>()?;
-
-        messages.push(prompt_message);
+        let mut full_history = vec![];
+        if let Some(docs) = completion_request.normalized_documents() {
+            full_history.push(docs);
+        }
+        full_history.extend(completion_request.chat_history);
+
+        let full_history = full_history
+            .into_iter()
+            .map(Message::try_from)
+            .collect::<Result<Vec<Message>, _>>()?;

         let mut request = json!({
             "model": self.model,
-            "messages": messages,
+            "messages": full_history,
             "max_tokens": max_tokens,
             "system": completion_request.preamble.unwrap_or("".to_string()),
         });
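The same migration repeats across every provider below, so it helps to state the ordering contract once: a request is serialized as [preamble?] ++ [normalized documents?] ++ chat_history, where `chat_history` is never empty and ends with the prompt. A condensed sketch of the shared pattern (not any one provider's exact code):

let mut full_history = vec![];
if let Some(docs) = completion_request.normalized_documents() {
    full_history.push(docs); // documents ride along as one leading user message
}
full_history.extend(completion_request.chat_history); // the prompt is the last element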
@ -7,7 +7,6 @@ use super::completion::{CompletionModel, Content, Message, ToolChoice, ToolDefin
 use super::decoders::sse::from_response as sse_from_response;
 use crate::completion::{CompletionError, CompletionRequest};
 use crate::json_utils::merge_inplace;
-use crate::message::MessageError;
 use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};

 #[derive(Debug, Deserialize)]

@ -90,26 +89,20 @@ impl StreamingCompletionModel for CompletionModel {
             ));
         };

-        let prompt_message: Message = completion_request
-            .prompt_with_context()
-            .try_into()
-            .map_err(|e: MessageError| CompletionError::RequestError(e.into()))?;
-
-        let mut messages = completion_request
-            .chat_history
-            .into_iter()
-            .map(|message| {
-                message
-                    .try_into()
-                    .map_err(|e: MessageError| CompletionError::RequestError(e.into()))
-            })
-            .collect::<Result<Vec<Message>, _>>()?;
-
-        messages.push(prompt_message);
+        let mut full_history = vec![];
+        if let Some(docs) = completion_request.normalized_documents() {
+            full_history.push(docs);
+        }
+        full_history.extend(completion_request.chat_history);
+
+        let full_history = full_history
+            .into_iter()
+            .map(Message::try_from)
+            .collect::<Result<Vec<Message>, _>>()?;

         let mut request = json!({
             "model": self.model,
-            "messages": messages,
+            "messages": full_history,
             "max_tokens": max_tokens,
             "system": completion_request.preamble.unwrap_or("".to_string()),
             "stream": true,
@ -480,16 +480,14 @@ impl CompletionModel {
         &self,
         completion_request: CompletionRequest,
     ) -> Result<serde_json::Value, CompletionError> {
-        // Add preamble to chat history (if available)
         let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
             Some(preamble) => vec![openai::Message::system(preamble)],
             None => vec![],
         };
+        if let Some(docs) = completion_request.normalized_documents() {
+            let docs: Vec<openai::Message> = docs.try_into()?;
+            full_history.extend(docs);
+        }

-        // Convert prompt to user message
-        let prompt: Vec<openai::Message> = completion_request.prompt_with_context().try_into()?;
-
-        // Convert existing chat history
         let chat_history: Vec<openai::Message> = completion_request
             .chat_history
             .into_iter()

@ -499,9 +497,7 @@
             .flatten()
             .collect();

-        // Combine all messages into a single history
         full_history.extend(chat_history);
-        full_history.extend(prompt);

         let request = if completion_request.tools.is_empty() {
             json!({

@ -786,6 +782,7 @@ mod azure_tests {
     use crate::completion::CompletionModel;
     use crate::embeddings::EmbeddingModel;
+    use crate::OneOrMany;

     #[tokio::test]
     #[ignore]

@ -812,8 +809,7 @@
         let completion = model
             .completion(CompletionRequest {
                 preamble: Some("You are a helpful assistant.".to_string()),
-                chat_history: vec![],
-                prompt: "Hello, world!".into(),
+                chat_history: OneOrMany::one("Hello!".into()),
                 documents: vec![],
                 max_tokens: Some(100),
                 temperature: Some(0.0),
@ -440,29 +440,34 @@ impl completion::CompletionModel for CompletionModel {
         &self,
         completion_request: completion::CompletionRequest,
     ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
-        let prompt = completion_request.prompt_with_context();
-
-        let mut messages: Vec<message::Message> =
-            if let Some(preamble) = completion_request.preamble {
-                vec![preamble.into()]
-            } else {
-                vec![]
-            };
-
-        messages.extend(completion_request.chat_history);
-        messages.push(prompt);
-
-        let messages: Vec<Message> = messages
-            .into_iter()
-            .map(|msg| msg.try_into())
-            .collect::<Result<Vec<Vec<_>>, _>>()?
-            .into_iter()
-            .flatten()
-            .collect();
+        // Build up the order of messages (context, chat_history)
+        let mut partial_history = vec![];
+        if let Some(docs) = completion_request.normalized_documents() {
+            partial_history.push(docs);
+        }
+        partial_history.extend(completion_request.chat_history);
+
+        // Initialize full history with preamble (or empty if non-existent)
+        let mut full_history: Vec<Message> = completion_request
+            .preamble
+            .map_or_else(Vec::new, |preamble| {
+                vec![Message::System { content: preamble }]
+            });
+
+        // Convert and extend the rest of the history
+        full_history.extend(
+            partial_history
+                .into_iter()
+                .map(message::Message::try_into)
+                .collect::<Result<Vec<Vec<Message>>, _>>()?
+                .into_iter()
+                .flatten()
+                .collect::<Vec<_>>(),
+        );

         let request = json!({
             "model": self.model,
-            "messages": messages,
+            "messages": full_history,
             "documents": completion_request.documents,
             "temperature": completion_request.temperature,
             "tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
@ -379,28 +379,28 @@ impl DeepSeekCompletionModel {
         &self,
         completion_request: CompletionRequest,
     ) -> Result<serde_json::Value, CompletionError> {
-        // Add preamble to chat history (if available)
-        let mut full_history: Vec<Message> = match &completion_request.preamble {
-            Some(preamble) => vec![Message::system(preamble)],
-            None => vec![],
-        };
-
-        // Convert prompt to user message
-        let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
-
-        // Convert existing chat history
-        let chat_history: Vec<Message> = completion_request
-            .chat_history
-            .into_iter()
-            .map(|message| message.try_into())
-            .collect::<Result<Vec<Vec<Message>>, _>>()?
-            .into_iter()
-            .flatten()
-            .collect();
-
-        // Combine all messages into a single history
-        full_history.extend(chat_history);
-        full_history.extend(prompt);
+        // Build up the order of messages (context, chat_history, prompt)
+        let mut partial_history = vec![];
+        if let Some(docs) = completion_request.normalized_documents() {
+            partial_history.push(docs);
+        }
+        partial_history.extend(completion_request.chat_history);
+
+        // Initialize full history with preamble (or empty if non-existent)
+        let mut full_history: Vec<Message> = completion_request
+            .preamble
+            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
+
+        // Convert and extend the rest of the history
+        full_history.extend(
+            partial_history
+                .into_iter()
+                .map(message::Message::try_into)
+                .collect::<Result<Vec<Vec<Message>>, _>>()?
+                .into_iter()
+                .flatten()
+                .collect::<Vec<_>>(),
+        );

         let request = if completion_request.tools.is_empty() {
             json!({
@ -398,6 +398,13 @@ impl CompletionModel {
         &self,
         completion_request: CompletionRequest,
     ) -> Result<Value, CompletionError> {
+        // Build up the order of messages (context, chat_history, prompt)
+        let mut partial_history = vec![];
+        if let Some(docs) = completion_request.normalized_documents() {
+            partial_history.push(docs);
+        }
+        partial_history.extend(completion_request.chat_history);
+
         // Add preamble to chat history (if available)
         let mut full_history: Vec<Message> = match &completion_request.preamble {
             Some(preamble) => vec![Message {

@ -408,19 +415,13 @@
             None => vec![],
         };

-        // Convert prompt to user message
-        let prompt: Message = completion_request.prompt_with_context().try_into()?;
-
-        // Convert existing chat history
-        let chat_history: Vec<Message> = completion_request
-            .chat_history
-            .into_iter()
-            .map(|message| message.try_into())
-            .collect::<Result<Vec<Message>, _>>()?;
-
-        // Combine all messages into a single history
-        full_history.extend(chat_history);
-        full_history.push(prompt);
+        // Convert and extend the rest of the history
+        full_history.extend(
+            partial_history
+                .into_iter()
+                .map(message::Message::try_into)
+                .collect::<Result<Vec<Message>, _>>()?,
+        );

         let request = if completion_request.tools.is_empty() {
             json!({
@ -93,11 +93,10 @@ impl completion::CompletionModel for CompletionModel {
 }

 pub(crate) fn create_request_body(
-    mut completion_request: CompletionRequest,
+    completion_request: CompletionRequest,
 ) -> Result<GenerateContentRequest, CompletionError> {
     let mut full_history = Vec::new();
-    full_history.append(&mut completion_request.chat_history);
-    full_history.push(completion_request.prompt_with_context());
+    full_history.extend(completion_request.chat_history);

     let additional_params = completion_request
         .additional_params
@ -279,28 +279,31 @@ impl CompletionModel {
         &self,
         completion_request: CompletionRequest,
     ) -> Result<Value, CompletionError> {
-        // Add preamble to chat history (if available)
-        let mut full_history: Vec<Message> = match &completion_request.preamble {
-            Some(preamble) => vec![Message {
-                role: "system".to_string(),
-                content: Some(preamble.to_string()),
-            }],
-            None => vec![],
-        };
-
-        // Convert prompt to user message
-        let prompt: Message = completion_request.prompt_with_context().try_into()?;
-
-        // Convert existing chat history
-        let chat_history: Vec<Message> = completion_request
-            .chat_history
-            .into_iter()
-            .map(|message| message.try_into())
-            .collect::<Result<Vec<Message>, _>>()?;
-
-        // Combine all messages into a single history
-        full_history.extend(chat_history);
-        full_history.push(prompt);
+        // Build up the order of messages (context, chat_history, prompt)
+        let mut partial_history = vec![];
+        if let Some(docs) = completion_request.normalized_documents() {
+            partial_history.push(docs);
+        }
+        partial_history.extend(completion_request.chat_history);
+
+        // Initialize full history with preamble (or empty if non-existent)
+        let mut full_history: Vec<Message> =
+            completion_request
+                .preamble
+                .map_or_else(Vec::new, |preamble| {
+                    vec![Message {
+                        role: "system".to_string(),
+                        content: Some(preamble),
+                    }]
+                });
+
+        // Convert and extend the rest of the history
+        full_history.extend(
+            partial_history
+                .into_iter()
+                .map(message::Message::try_into)
+                .collect::<Result<Vec<Message>, _>>()?,
+        );

         let request = if completion_request.tools.is_empty() {
             json!({
@ -502,8 +502,10 @@ impl CompletionModel {
             Some(preamble) => vec![Message::system(preamble)],
             None => vec![],
         };
-        let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
+        if let Some(docs) = completion_request.normalized_documents() {
+            let docs: Vec<Message> = docs.try_into()?;
+            full_history.extend(docs);
+        }

         let chat_history: Vec<Message> = completion_request
             .chat_history

@ -516,7 +518,6 @@
             .collect();

         full_history.extend(chat_history);
-        full_history.extend(prompt);

         let model = self.client.sub_provider.model_identifier(&self.model);
@ -12,6 +12,7 @@
 use super::openai::{send_compatible_streaming_request, AssistantContent};

 use crate::json_utils::merge_inplace;
+use crate::message;
 use crate::streaming::{StreamingCompletionModel, StreamingResult};
 use crate::{
     agent::AgentBuilder,

@ -306,28 +307,28 @@ impl CompletionModel {
         &self,
         completion_request: CompletionRequest,
     ) -> Result<Value, CompletionError> {
-        // Add preamble to chat history (if available)
-        let mut full_history: Vec<Message> = match &completion_request.preamble {
-            Some(preamble) => vec![Message::system(preamble)],
-            None => vec![],
-        };
-
-        // Convert prompt to user message
-        let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
-
-        // Convert existing chat history
-        let chat_history: Vec<Message> = completion_request
-            .chat_history
-            .into_iter()
-            .map(|message| message.try_into())
-            .collect::<Result<Vec<Vec<Message>>, _>>()?
-            .into_iter()
-            .flatten()
-            .collect();
-
-        // Combine all messages into a single history
-        full_history.extend(chat_history);
-        full_history.extend(prompt);
+        // Build up the order of messages (context, chat_history, prompt)
+        let mut partial_history = vec![];
+        if let Some(docs) = completion_request.normalized_documents() {
+            partial_history.push(docs);
+        }
+        partial_history.extend(completion_request.chat_history);
+
+        // Initialize full history with preamble (or empty if non-existent)
+        let mut full_history: Vec<Message> = completion_request
+            .preamble
+            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
+
+        // Convert and extend the rest of the history
+        full_history.extend(
+            partial_history
+                .into_iter()
+                .map(message::Message::try_into)
+                .collect::<Result<Vec<Vec<Message>>, _>>()?
+                .into_iter()
+                .flatten()
+                .collect::<Vec<_>>(),
+        );

         let request = json!({
             "model": self.model,
@ -238,24 +238,25 @@ impl CompletionModel {
             }));
         }

-        // Add prompt
-        messages.push(match &completion_request.prompt {
-            Message::User { content } => {
-                let text = content
-                    .iter()
-                    .map(|c| match c {
-                        UserContent::Text(text) => &text.text,
-                        _ => "",
-                    })
-                    .collect::<Vec<_>>()
-                    .join("\n");
-                serde_json::json!({
-                    "role": "user",
-                    "content": text
-                })
-            }
-            _ => unreachable!(),
-        });
+        // Add docs
+        if let Some(Message::User { content }) = completion_request.normalized_documents() {
+            let text = content
+                .into_iter()
+                .filter_map(|doc| match doc {
+                    UserContent::Document(doc) => Some(doc.data),
+                    UserContent::Text(text) => Some(text.text),
+                    // This should always be `Document`
+                    _ => None,
+                })
+                .collect::<Vec<_>>()
+                .join("\n");
+
+            messages.push(serde_json::json!({
+                "role": "user",
+                "content": text
+            }));
+        }

         // Add chat history
         for msg in completion_request.chat_history {
@ -10,6 +10,7 @@
 //! ```

 use crate::json_utils::merge;
+use crate::message;
 use crate::providers::openai::send_compatible_streaming_request;
 use crate::streaming::{StreamingCompletionModel, StreamingResult};
 use crate::{

@ -141,28 +142,30 @@ impl CompletionModel {
         &self,
         completion_request: CompletionRequest,
     ) -> Result<Value, CompletionError> {
-        // Add preamble to chat history (if available)
-        let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
-            Some(preamble) => vec![openai::Message::system(preamble)],
-            None => vec![],
-        };
-
-        // Convert prompt to user message
-        let prompt: Vec<openai::Message> = completion_request.prompt_with_context().try_into()?;
-
-        // Convert existing chat history
-        let chat_history: Vec<openai::Message> = completion_request
-            .chat_history
-            .into_iter()
-            .map(|message| message.try_into())
-            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
-            .into_iter()
-            .flatten()
-            .collect();
-
-        // Combine all messages into a single history
-        full_history.extend(chat_history);
-        full_history.extend(prompt);
+        // Build up the order of messages (context, chat_history)
+        let mut partial_history = vec![];
+        if let Some(docs) = completion_request.normalized_documents() {
+            partial_history.push(docs);
+        }
+        partial_history.extend(completion_request.chat_history);
+
+        // Initialize full history with preamble (or empty if non-existent)
+        let mut full_history: Vec<openai::Message> = completion_request
+            .preamble
+            .map_or_else(Vec::new, |preamble| {
+                vec![openai::Message::system(&preamble)]
+            });
+
+        // Convert and extend the rest of the history
+        full_history.extend(
+            partial_history
+                .into_iter()
+                .map(message::Message::try_into)
+                .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
+                .into_iter()
+                .flatten()
+                .collect::<Vec<_>>(),
+        );

         let request = if completion_request.tools.is_empty() {
             json!({
@ -325,8 +325,27 @@ impl CompletionModel {
         &self,
         completion_request: CompletionRequest,
     ) -> Result<Value, CompletionError> {
+        // Build up the order of messages (context, chat_history)
+        let mut partial_history = vec![];
+        if let Some(docs) = completion_request.normalized_documents() {
+            partial_history.push(docs);
+        }
+        partial_history.extend(completion_request.chat_history);
+
+        // Initialize full history with preamble (or empty if non-existent)
+        let mut full_history: Vec<Message> = completion_request
+            .preamble
+            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
+
+        // Convert and extend the rest of the history
+        full_history.extend(
+            partial_history
+                .into_iter()
+                .map(|msg| msg.try_into())
+                .collect::<Result<Vec<Message>, _>>()?,
+        );
+
-        // Convert internal prompt into a provider Message
-        let prompt: Message = completion_request.prompt_with_context().try_into()?;
         let options = if let Some(extra) = completion_request.additional_params {
             json_utils::merge(
                 json!({ "temperature": completion_request.temperature }),

@ -336,16 +355,6 @@
             json!({ "temperature": completion_request.temperature })
         };

-        // Chat mode: assemble full conversation history including preamble and chat history
-        let mut full_history = Vec::new();
-        if let Some(preamble) = completion_request.preamble {
-            full_history.push(Message::system(&preamble));
-        }
-        for msg in completion_request.chat_history.into_iter() {
-            full_history.push(Message::try_from(msg)?);
-        }
-        full_history.push(prompt);
-
         let mut request_payload = json!({
             "model": self.model,
             "messages": full_history,
@@ -605,28 +605,28 @@ impl CompletionModel {
         &self,
         completion_request: CompletionRequest,
     ) -> Result<Value, CompletionError> {
-        // Add preamble to chat history (if available)
-        let mut full_history: Vec<Message> = match &completion_request.preamble {
-            Some(preamble) => vec![Message::system(preamble)],
-            None => vec![],
-        };
+        // Build up the order of messages (context, chat_history)
+        let mut partial_history = vec![];
+        if let Some(docs) = completion_request.normalized_documents() {
+            partial_history.push(docs);
+        }
+        partial_history.extend(completion_request.chat_history);

-        // Convert prompt to user message
-        let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
+        // Initialize full history with preamble (or empty if non-existent)
+        let mut full_history: Vec<Message> = completion_request
+            .preamble
+            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);

-        // Convert existing chat history
-        let chat_history: Vec<Message> = completion_request
-            .chat_history
+        // Convert and extend the rest of the history
+        full_history.extend(
+            partial_history
                 .into_iter()
-                .map(|message| message.try_into())
+                .map(message::Message::try_into)
                 .collect::<Result<Vec<Vec<Message>>, _>>()?
                 .into_iter()
                 .flatten()
-                .collect();
+                .collect::<Vec<_>>(),
+        );

-        // Combine all messages into a single history
-        full_history.extend(chat_history);
-        full_history.extend(prompt);
-
         let request = if completion_request.tools.is_empty() {
             json!({
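The hunk above captures the core of this refactor: context documents (when present) are normalized into a single message that leads the history, the chat history (which after this change already ends with the latest prompt) follows, and the preamble seeds full_history exactly once at the front. A self-contained sketch of that ordering, using stand-in types since rig's real Message and normalized_documents() are richer:

// Stand-in types; rig's actual message enum and request struct differ.
#[derive(Debug)]
enum Message {
    System(String),
    User(String),
}

struct Request {
    preamble: Option<String>,
    documents: Vec<String>,
    chat_history: Vec<Message>,
}

impl Request {
    // Mirrors the spirit of `normalized_documents()`: fold all docs
    // into one leading user message, or None when there are no docs.
    fn normalized_documents(&self) -> Option<Message> {
        if self.documents.is_empty() {
            return None;
        }
        Some(Message::User(self.documents.join(" | ")))
    }
}

fn main() {
    let request = Request {
        preamble: Some("You are a helpful assistant.".into()),
        documents: vec!["doc one".into(), "doc two".into()],
        chat_history: vec![Message::User("What do the docs say?".into())],
    };

    // Same order as the new code: docs first, then chat history (prompt included).
    let mut partial_history = vec![];
    if let Some(docs) = request.normalized_documents() {
        partial_history.push(docs);
    }
    partial_history.extend(request.chat_history);

    // Preamble (if any) initializes the full history, exactly once, at the front.
    let mut full_history: Vec<Message> = request
        .preamble
        .map_or_else(Vec::new, |p| vec![Message::System(p)]);
    full_history.extend(partial_history);

    println!("{full_history:?}");
}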
@@ -269,8 +269,11 @@ impl completion::CompletionModel for CompletionModel {
             None => vec![],
         };

-        // Convert prompt to user message
-        let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
+        // Gather docs
+        if let Some(docs) = completion_request.normalized_documents() {
+            let docs: Vec<Message> = docs.try_into()?;
+            full_history.extend(docs);
+        }

         // Convert existing chat history
         let chat_history: Vec<Message> = completion_request
@@ -284,7 +287,6 @@ impl completion::CompletionModel for CompletionModel {

         // Combine all messages into a single history
         full_history.extend(chat_history);
-        full_history.extend(prompt);

         let request = json!({
             "model": self.model,
@@ -204,39 +204,36 @@ impl CompletionModel {
         &self,
         completion_request: CompletionRequest,
     ) -> Result<Value, CompletionError> {
-        // Add context documents to current prompt
-        let prompt_with_context = completion_request.prompt_with_context();
-
-        // Add preamble to messages (if available)
-        let mut messages: Vec<Message> = if let Some(preamble) = completion_request.preamble {
-            vec![Message {
-                role: Role::System,
-                content: preamble,
-            }]
-        } else {
-            vec![]
-        };
-
-        // Add chat history to messages
-        for message in completion_request.chat_history {
-            messages.push(
-                message
-                    .try_into()
-                    .map_err(|e: MessageError| CompletionError::RequestError(e.into()))?,
-            );
+        // Build up the order of messages (context, chat_history, prompt)
+        let mut partial_history = vec![];
+        if let Some(docs) = completion_request.normalized_documents() {
+            partial_history.push(docs);
         }
+        partial_history.extend(completion_request.chat_history);

-        // Add user prompt to messages
-        messages.push(
-            prompt_with_context
-                .try_into()
-                .map_err(|e: MessageError| CompletionError::RequestError(e.into()))?,
+        // Initialize full history with preamble (or empty if non-existent)
+        let mut full_history: Vec<Message> =
+            completion_request
+                .preamble
+                .map_or_else(Vec::new, |preamble| {
+                    vec![Message {
+                        role: Role::System,
+                        content: preamble,
+                    }]
+                });
+
+        // Convert and extend the rest of the history
+        full_history.extend(
+            partial_history
+                .into_iter()
+                .map(message::Message::try_into)
+                .collect::<Result<Vec<Message>, _>>()?,
         );

         // Compose request
         let request = json!({
             "model": self.model,
-            "messages": messages,
+            "messages": full_history,
             "temperature": completion_request.temperature,
         });

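This provider converts each history element through a fallible TryInto and collects straight into a Result, so the first malformed message aborts the whole request instead of being silently dropped. A small stand-in sketch of that collect-into-Result idiom; the Internal and Provider types here are hypothetical, not rig's:

// Stand-in internal and provider message types; rig's real conversion is richer.
struct Internal(String);
struct Provider {
    role: String,
    content: String,
}

impl TryFrom<Internal> for Provider {
    type Error = String;

    fn try_from(msg: Internal) -> Result<Self, Self::Error> {
        if msg.0.is_empty() {
            return Err("empty message".into());
        }
        Ok(Provider {
            role: "user".into(),
            content: msg.0,
        })
    }
}

fn convert(history: Vec<Internal>) -> Result<Vec<Provider>, String> {
    // `collect` on an iterator of Results short-circuits at the first Err,
    // which is exactly what `.collect::<Result<...>>()?` relies on above.
    history.into_iter().map(Provider::try_from).collect()
}

fn main() {
    let ok = convert(vec![Internal("hi".into())]);
    assert!(ok.is_ok());
    let err = convert(vec![Internal(String::new())]);
    assert_eq!(err.unwrap_err(), "empty message");
}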
@@ -148,7 +148,10 @@ impl CompletionModel {
             Some(preamble) => vec![openai::Message::system(preamble)],
             None => vec![],
         };
-        let prompt: Vec<openai::Message> = completion_request.prompt_with_context().try_into()?;
+        if let Some(docs) = completion_request.normalized_documents() {
+            let docs: Vec<openai::Message> = docs.try_into()?;
+            full_history.extend(docs);
+        }
         let chat_history: Vec<openai::Message> = completion_request
             .chat_history
             .into_iter()
@@ -157,8 +160,9 @@ impl CompletionModel {
             .into_iter()
             .flatten()
             .collect();

         full_history.extend(chat_history);
-        full_history.extend(prompt);
+
         let mut request = if completion_request.tools.is_empty() {
             json!({
                 "model": self.model,
@@ -10,7 +10,6 @@ use crate::{
 };

 use super::client::{xai_api_types::ApiResponse, Client};
-use crate::completion::CompletionRequest;
 use serde_json::{json, Value};
 use xai_api_types::{CompletionResponse, ToolDefinition};

@@ -30,22 +29,13 @@ pub struct CompletionModel {
 impl CompletionModel {
     pub(crate) fn create_completion_request(
         &self,
-        completion_request: CompletionRequest,
+        completion_request: completion::CompletionRequest,
     ) -> Result<Value, CompletionError> {
-        // Add preamble to chat history (if available)
-        let mut full_history: Vec<Message> = match &completion_request.preamble {
-            Some(preamble) => {
-                if preamble.is_empty() {
-                    vec![]
-                } else {
-                    vec![Message::system(preamble)]
-                }
-            }
-            None => vec![],
-        };
-
-        // Convert prompt to user message
-        let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
+        // Convert documents into user message
+        let docs: Option<Vec<Message>> = completion_request
+            .normalized_documents()
+            .map(|docs| docs.try_into())
+            .transpose()?;

         // Convert existing chat history
         let chat_history: Vec<Message> = completion_request
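The .map(|docs| docs.try_into()).transpose()? line above does the heavy lifting here: it flips an Option<Result<_, _>> into a Result<Option<_>, _> so that ? fires only when documents exist and fail to convert. A compact illustration using only the standard library:

fn parse_docs(raw: Option<&str>) -> Result<Option<u32>, std::num::ParseIntError> {
    // Option<Result<u32, E>> -> Result<Option<u32>, E>; `?` then applies only
    // when a value was present and its conversion failed.
    raw.map(|s| s.parse::<u32>()).transpose()
}

fn main() {
    assert_eq!(parse_docs(None).unwrap(), None);
    assert_eq!(parse_docs(Some("42")).unwrap(), Some(42));
    assert!(parse_docs(Some("not a number")).is_err());
}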
@@ -57,9 +47,19 @@ impl CompletionModel {
             .flatten()
             .collect();

-        // Combine all messages into a single history
+        // Init full history with preamble (or empty if non-existant)
+        let mut full_history: Vec<Message> = match &completion_request.preamble {
+            Some(preamble) => vec![Message::system(preamble)],
+            None => vec![],
+        };
+
+        // Docs appear right after preamble, if they exist
+        if let Some(docs) = docs {
+            full_history.extend(docs)
+        }
+
+        // Chat history and prompt appear in the order they were provided
         full_history.extend(chat_history);
-        full_history.extend(prompt);

         let mut request = if completion_request.tools.is_empty() {
             json!({
@@ -15,6 +15,7 @@ use rig::agent::AgentBuilder;
 use rig::completion::{CompletionError, CompletionRequest};
 use rig::embeddings::{EmbeddingError, EmbeddingsBuilder};
 use rig::extractor::ExtractorBuilder;
+use rig::message;
 use rig::providers::openai::{self, Message};
 use rig::OneOrMany;
 use rig::{completion, embeddings, Embed};
@@ -470,14 +471,19 @@ impl completion::CompletionModel for CompletionModel {
         &self,
         completion_request: CompletionRequest,
     ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
-        // Add preamble to chat history (if available)
-        let mut full_history: Vec<Message> = match &completion_request.preamble {
-            Some(preamble) => vec![Message::system(preamble)],
-            None => vec![],
-        };
+        // Build up the order of messages (context, chat_history)
+        let mut partial_history = vec![];
+        if let Some(docs) = completion_request.normalized_documents() {
+            partial_history.push(docs);
+        }
+        partial_history.extend(completion_request.chat_history);

+        // Initialize full history with preamble (or empty if non-existent)
+        let mut full_history: Vec<Message> = completion_request
+            .preamble
+            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
+
         // Convert prompt to user message
-        let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
         tracing::info!("Try to get on-chain system prompt");
         let eternal_ai_rpc = std::env::var("ETERNALAI_RPC_URL").unwrap_or_else(|_| "".to_string());
         let eternal_ai_contract =
@@ -515,19 +521,16 @@ impl completion::CompletionModel for CompletionModel {
             }
         }

-        // Convert existing chat history
-        let chat_history: Vec<Message> = completion_request
-            .chat_history
+        // Convert and extend the rest of the history
+        full_history.extend(
+            partial_history
                 .into_iter()
-                .map(|message| message.try_into())
+                .map(message::Message::try_into)
                 .collect::<Result<Vec<Vec<Message>>, _>>()?
                 .into_iter()
                 .flatten()
-                .collect();
+                .collect::<Vec<_>>(),
+        );

-        // Combine all messages into a single history
-        full_history.extend(chat_history);
-        full_history.extend(prompt);
-
         let request = if completion_request.tools.is_empty() {
             json!({
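As in the earlier hunks, one internal message may fan out into several provider messages (for example, a tool call plus its result), which is why the conversion collects into Vec<Vec<Message>> before flattening. A minimal sketch of that one-to-many fallible mapping; expand and Provider are stand-ins, not rig's API:

#[derive(Debug)]
struct Provider(String);

// Stand-in: one internal entry may fan out into several provider messages.
fn expand(entry: &str) -> Result<Vec<Provider>, String> {
    if entry.is_empty() {
        return Err("empty entry".into());
    }
    Ok(entry.split(',').map(|part| Provider(part.to_string())).collect())
}

fn main() -> Result<(), String> {
    let history = vec!["question", "tool_call,tool_result"];

    // Collect the fallible expansions first, then flatten Vec<Vec<_>> -> Vec<_>,
    // mirroring `.collect::<Result<Vec<Vec<Message>>, _>>()?.into_iter().flatten()`.
    let flat: Vec<Provider> = history
        .into_iter()
        .map(expand)
        .collect::<Result<Vec<Vec<Provider>>, _>>()?
        .into_iter()
        .flatten()
        .collect();

    assert_eq!(flat.len(), 3);
    println!("{flat:?}");
    Ok(())
}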
|
|
Loading…
Reference in New Issue