From 01d5448921aaab81c489d877d87cb9e824d1ab5a Mon Sep 17 00:00:00 2001 From: Joshua Mo Date: Fri, 28 Mar 2025 17:15:27 +0000 Subject: [PATCH 01/11] feat: basic completion request service for tower integration --- Cargo.lock | 1 + rig-core/Cargo.toml | 1 + rig-core/src/lib.rs | 1 + rig-core/src/middlewares/completion.rs | 33 ++++++++++++++++++++++++++ rig-core/src/middlewares/mod.rs | 1 + 5 files changed, 37 insertions(+) create mode 100644 rig-core/src/middlewares/completion.rs create mode 100644 rig-core/src/middlewares/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 6b45006..3ad7b08 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8612,6 +8612,7 @@ dependencies = [ "thiserror 1.0.69", "tokio", "tokio-test", + "tower 0.5.2", "tracing", "tracing-subscriber", "worker", diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index 56e173d..72e5d12 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -39,6 +39,7 @@ bytes = "1.9.0" async-stream = "0.3.6" mime_guess = { version = "2.0.5" } base64 = { version = "0.22.1" } +tower = "0.5.2" [dev-dependencies] diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 300c962..a13d184 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -91,6 +91,7 @@ pub mod extractor; pub mod image_generation; pub(crate) mod json_utils; pub mod loaders; +pub mod middlewares; pub mod one_or_many; pub mod pipeline; pub mod providers; diff --git a/rig-core/src/middlewares/completion.rs b/rig-core/src/middlewares/completion.rs new file mode 100644 index 0000000..16ea04e --- /dev/null +++ b/rig-core/src/middlewares/completion.rs @@ -0,0 +1,33 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use tower::Service; + +use crate::completion::{CompletionError, CompletionModel, CompletionRequest, CompletionResponse}; + +pub struct CompletionService { + model: M, +} + +impl Service for CompletionService +where + M: CompletionModel + 'static, +{ + type Response = CompletionResponse; + type Error = CompletionError; + + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: CompletionRequest) -> Self::Future { + let model = self.model.clone(); + + Box::pin(async move { model.completion(req).await }) + } +} diff --git a/rig-core/src/middlewares/mod.rs b/rig-core/src/middlewares/mod.rs new file mode 100644 index 0000000..c893d10 --- /dev/null +++ b/rig-core/src/middlewares/mod.rs @@ -0,0 +1 @@ +pub mod completion; From d990a99c05ba1376cd5f4f75d105fa9a74e7b013 Mon Sep 17 00:00:00 2001 From: Joshua Mo Date: Fri, 28 Mar 2025 17:18:05 +0000 Subject: [PATCH 02/11] docs: add docstrings --- rig-core/src/middlewares/completion.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/rig-core/src/middlewares/completion.rs b/rig-core/src/middlewares/completion.rs index 16ea04e..33520a1 100644 --- a/rig-core/src/middlewares/completion.rs +++ b/rig-core/src/middlewares/completion.rs @@ -8,7 +8,11 @@ use tower::Service; use crate::completion::{CompletionError, CompletionModel, CompletionRequest, CompletionResponse}; +/// A completion model as a Tower service. +/// +/// This allows you to use an LLM model (or client) essentially anywhere you'd use a regular Tower layer, like in an Axum web service. pub struct CompletionService { + /// The model itself. model: M, } From a6ba53be0a8da78c5c925163e7efa890bd9889e3 Mon Sep 17 00:00:00 2001 From: Joshua Mo Date: Tue, 1 Apr 2025 12:18:12 +0100 Subject: [PATCH 03/11] feat: expand tool range --- rig-core/src/middlewares/completion.rs | 297 ++++++++++++++++++++++++- rig-core/src/middlewares/extractor.rs | 139 ++++++++++++ rig-core/src/middlewares/mod.rs | 19 ++ rig-core/src/middlewares/rag.rs | 54 +++++ rig-core/src/middlewares/tools.rs | 131 +++++++++++ 5 files changed, 638 insertions(+), 2 deletions(-) create mode 100644 rig-core/src/middlewares/extractor.rs create mode 100644 rig-core/src/middlewares/rag.rs create mode 100644 rig-core/src/middlewares/tools.rs diff --git a/rig-core/src/middlewares/completion.rs b/rig-core/src/middlewares/completion.rs index 33520a1..86aa972 100644 --- a/rig-core/src/middlewares/completion.rs +++ b/rig-core/src/middlewares/completion.rs @@ -4,9 +4,293 @@ use std::{ task::{Context, Poll}, }; -use tower::Service; +use serde::{Deserialize, Serialize}; +use tower::{Layer, Service}; -use crate::completion::{CompletionError, CompletionModel, CompletionRequest, CompletionResponse}; +use crate::{ + completion::{ + CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, + CompletionResponse, Document, ToolDefinition, + }, + message::{Message, ToolResultContent, UserContent}, + OneOrMany, +}; + +use super::rag::RagResult; + +pub struct CompletionLayer { + model: M, + preamble: Option, + chat_history: Vec, + documents: Vec, + tools: Vec, + temperature: Option, + max_tokens: Option, + additional_params: Option, +} + +impl CompletionLayer +where + M: CompletionModel, +{ + pub fn builder(model: M) -> CompletionLayerBuilder { + CompletionLayerBuilder::new(model) + } +} + +#[derive(Default)] +pub struct CompletionLayerBuilder { + model: M, + preamble: Option, + chat_history: Vec, + documents: Vec, + tools: Vec, + temperature: Option, + max_tokens: Option, + additional_params: Option, +} + +impl CompletionLayerBuilder +where + M: CompletionModel, +{ + pub fn new(model: M) -> Self { + Self { + model, + preamble: None, + chat_history: vec![], + documents: vec![], + tools: vec![], + temperature: None, + max_tokens: None, + additional_params: None, + } + } + + pub fn preamble(mut self, preamble: String) -> Self { + self.preamble = Some(preamble); + + self + } + + pub fn preamble_opt(mut self, preamble: Option) -> Self { + self.preamble = preamble; + + self + } + + pub fn chat_history(mut self, chat_history: Vec) -> Self { + self.chat_history = chat_history; + + self + } + + pub fn documents(mut self, documents: Vec) -> Self { + self.documents = documents; + + self + } + + pub fn tools(mut self, tools: Vec) -> Self { + self.tools = tools; + + self + } + + pub fn temperature(mut self, temperature: f64) -> Self { + self.temperature = Some(temperature); + + self + } + + pub fn temperature_opt(mut self, temperature: Option) -> Self { + self.temperature = temperature; + + self + } + + pub fn max_tokens(mut self, max_tokens: u64) -> Self { + self.max_tokens = Some(max_tokens); + + self + } + + pub fn max_tokens_opt(mut self, max_tokens: Option) -> Self { + self.max_tokens = max_tokens; + + self + } + + pub fn additional_params(mut self, params: serde_json::Value) -> Self { + self.additional_params = Some(params); + + self + } + + pub fn additional_params_opt(mut self, params: Option) -> Self { + self.additional_params = params; + + self + } + + pub fn build(self) -> CompletionLayer { + CompletionLayer { + model: self.model, + preamble: self.preamble, + chat_history: self.chat_history, + documents: self.documents, + tools: self.tools, + temperature: self.temperature, + max_tokens: self.max_tokens, + additional_params: self.additional_params, + } + } +} + +impl Layer for CompletionLayer +where + M: CompletionModel, +{ + type Service = CompletionLayerService; + fn layer(&self, inner: S) -> Self::Service { + CompletionLayerService { + inner, + model: self.model.clone(), + preamble: self.preamble.clone(), + chat_history: self.chat_history.clone(), + documents: self.documents.clone(), + tools: self.tools.clone(), + temperature: self.temperature, + max_tokens: self.max_tokens, + additional_params: self.additional_params.clone(), + } + } +} + +pub struct CompletionLayerService { + inner: S, + model: M, + preamble: Option, + chat_history: Vec, + documents: Vec, + tools: Vec, + temperature: Option, + max_tokens: Option, + additional_params: Option, +} + +impl Service for CompletionLayerService +where + M: CompletionModel + 'static, + S: Service> + Clone + Send + 'static, + S::Future: Send, + T: Serialize + for<'a> Deserialize<'a>, +{ + type Response = CompletionResponse; + type Error = CompletionError; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: String) -> Self::Future { + let mut inner = self.inner.clone(); + let model = self.model.clone(); + let preamble = self.preamble.clone(); + let chat_history = self.chat_history.clone(); + let documents = self.documents.clone(); + let temperature = self.temperature.clone(); + let tools = self.tools.clone(); + let max_tokens = self.max_tokens.clone(); + let additional_params = self.additional_params.clone(); + + Box::pin(async move { + let Ok(res) = inner.call(req.clone()).await else { + todo!("Handle error properly"); + }; + + let res: String = res + .into_iter() + .filter_map(|x| serde_json::to_string_pretty(&x.2).ok()) + .collect::>() + .join("\n"); + + let req = format!("{req}\n\nContext:\n{res}"); + + let mut req = CompletionRequestBuilder::new(model.clone(), req) + .documents(documents.clone()) + .tools(tools.clone()) + .messages(chat_history.clone()) + .temperature_opt(temperature.clone()) + .max_tokens_opt(max_tokens.clone()) + .additional_params_opt(additional_params.clone()); + + if let Some(preamble) = preamble.clone() { + req = req.preamble(preamble); + } + + let req = req.build(); + + model.completion(req).await + }) + } +} + +impl Service for CompletionLayerService +where + M: CompletionModel + 'static, + S: Service, String, ToolResultContent)> + + Clone + + Send + + 'static, + S::Future: Send, +{ + type Response = CompletionResponse; + type Error = CompletionError; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: CompletionRequest) -> Self::Future { + let mut inner = self.inner.clone(); + let model = self.model.clone(); + let preamble = self.preamble.clone(); + let documents = self.documents.clone(); + let temperature = self.temperature.clone(); + let tools = self.tools.clone(); + let max_tokens = self.max_tokens.clone(); + let additional_params = self.additional_params.clone(); + + Box::pin(async move { + let Ok((messages, id, tool_content)) = inner.call(req).await else { + todo!("Handle error properly"); + }; + + let tool_result_message = Message::User { + content: OneOrMany::one(UserContent::tool_result(id, OneOrMany::one(tool_content))), + }; + + let mut req = CompletionRequestBuilder::new(model.clone(), tool_result_message) + .documents(documents.clone()) + .tools(tools.clone()) + .messages(messages) + .temperature_opt(temperature.clone()) + .max_tokens_opt(max_tokens.clone()) + .additional_params_opt(additional_params.clone()); + + if let Some(preamble) = preamble.clone() { + req = req.preamble(preamble); + } + + let req = req.build(); + + model.completion(req).await + }) + } +} /// A completion model as a Tower service. /// @@ -16,6 +300,15 @@ pub struct CompletionService { model: M, } +impl CompletionService +where + M: CompletionModel, +{ + pub fn new(model: M) -> Self { + Self { model } + } +} + impl Service for CompletionService where M: CompletionModel + 'static, diff --git a/rig-core/src/middlewares/extractor.rs b/rig-core/src/middlewares/extractor.rs new file mode 100644 index 0000000..e2eddd0 --- /dev/null +++ b/rig-core/src/middlewares/extractor.rs @@ -0,0 +1,139 @@ +use std::{fmt::Display, future::Future, marker::PhantomData, pin::Pin, sync::Arc, task::Poll}; + +use mime_guess::mime::PLAIN; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use tower::{Layer, Service}; + +use crate::{ + completion::{ + CompletionModel, CompletionRequest, CompletionRequestBuilder, CompletionResponse, + ToolDefinition, + }, + extractor::{ExtractionError, Extractor}, + message::{AssistantContent, Text}, + pipeline::agent_ops::prompt, + tool::Tool, +}; + +pub struct ExtractorLayer +where + M: CompletionModel, + T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, +{ + ext: Arc>, +} + +impl ExtractorLayer +where + M: CompletionModel, + T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, +{ + pub fn new(ext: Extractor) -> Self { + Self { ext: Arc::new(ext) } + } +} + +impl Layer for ExtractorLayer +where + M: CompletionModel + 'static, + T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static, +{ + type Service = ExtractorLayerService; + fn layer(&self, inner: S) -> Self::Service { + ExtractorLayerService { + inner, + ext: Arc::clone(&self.ext), + } + } +} + +pub struct ExtractorLayerService +where + M: CompletionModel + 'static, + T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static, +{ + inner: S, + ext: Arc>, +} + +impl Service for ExtractorLayerService +where + S: Service> + + Clone + + Send + + 'static, + S::Future: Send, + M: CompletionModel + 'static, + T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync, +{ + type Response = T; + type Error = ExtractionError; + type Future = Pin> + Send>>; + + fn poll_ready( + &mut self, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: CompletionRequest) -> Self::Future { + let ext = self.ext.clone(); + let mut inner = self.inner.clone(); + + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + todo!("Properly handle error"); + }; + + let AssistantContent::Text(Text { text }) = res.choice.first() else { + todo!("Handle errors properly"); + }; + + ext.extract(&text).await + }) + } +} + +pub struct ExtractorService +where + M: CompletionModel + 'static, + T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static, +{ + ext: Arc>, +} + +impl ExtractorService +where + M: CompletionModel + 'static, + T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static, +{ + pub fn new(ext: Extractor) -> Self { + Self { ext: Arc::new(ext) } + } +} + +impl Service

for ExtractorService +where + M: CompletionModel, + T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, + P: Display + Send + 'static, +{ + type Response = T; + type Error = ExtractionError; + type Future = Pin> + Send>>; + + fn poll_ready( + &mut self, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: P) -> Self::Future { + let ext = self.ext.clone(); + + Box::pin(async move { ext.extract(&req.to_string()).await }) + } +} diff --git a/rig-core/src/middlewares/mod.rs b/rig-core/src/middlewares/mod.rs index c893d10..534c421 100644 --- a/rig-core/src/middlewares/mod.rs +++ b/rig-core/src/middlewares/mod.rs @@ -1 +1,20 @@ +use thiserror::Error; + +use crate::{completion::CompletionError, extractor::ExtractionError, tool::ToolSetError}; + pub mod completion; +pub mod extractor; +pub mod rag; +pub mod tools; + +#[derive(Debug, Error)] +pub enum ServiceError { + #[error("{0}")] + ExtractionError(#[from] ExtractionError), + #[error("{0}")] + CompletionError(#[from] CompletionError), + #[error("{0}")] + ToolSetError(#[from] ToolSetError), + #[error("Received incorrect message, must be {0}")] + InvalidMessageType(String), +} diff --git a/rig-core/src/middlewares/rag.rs b/rig-core/src/middlewares/rag.rs new file mode 100644 index 0000000..3c11309 --- /dev/null +++ b/rig-core/src/middlewares/rag.rs @@ -0,0 +1,54 @@ +use std::{ + future::Future, + marker::PhantomData, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use serde::{Deserialize, Serialize}; +use tower::Service; + +use crate::vector_store::{VectorStoreError, VectorStoreIndex}; + +pub struct RagService { + vector_index: Arc, + num_results: usize, + _phantom: PhantomData, +} + +impl RagService +where + V: VectorStoreIndex, +{ + pub fn new(vector_index: V, num_results: usize) -> Self { + Self { + vector_index: Arc::new(vector_index), + num_results, + _phantom: PhantomData::default(), + } + } +} + +impl Service for RagService +where + V: VectorStoreIndex + 'static, + T: Serialize + for<'a> Deserialize<'a> + Send, +{ + type Response = RagResult; + type Error = VectorStoreError; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: String) -> Self::Future { + let vector_index = self.vector_index.clone(); + let num_results = self.num_results.clone(); + + Box::pin(async move { vector_index.top_n(&req, num_results).await }) + } +} + +pub type RagResult = Vec<(f64, String, T)>; diff --git a/rig-core/src/middlewares/tools.rs b/rig-core/src/middlewares/tools.rs new file mode 100644 index 0000000..5a5fc82 --- /dev/null +++ b/rig-core/src/middlewares/tools.rs @@ -0,0 +1,131 @@ +use crate::{ + completion::{CompletionRequest, CompletionResponse}, + message::{AssistantContent, Message, ToolResultContent, UserContent}, + providers::{self}, + tool::{ToolSet, ToolSetError}, + OneOrMany, +}; +use std::{future::Future, pin::Pin, sync::Arc, task::Poll}; + +use tower::{Layer, Service}; + +pub struct ToolLayer { + tools: Arc, +} + +impl ToolLayer { + pub fn new(tools: ToolSet) -> Self { + Self { + tools: Arc::new(tools), + } + } +} + +impl Layer for ToolLayer { + type Service = ToolLayerService; + + fn layer(&self, inner: S) -> Self::Service { + ToolLayerService { + inner, + tools: Arc::clone(&self.tools), + } + } +} + +pub struct ToolLayerService { + inner: S, + tools: Arc, +} + +impl Service for ToolLayerService +where + S: Service> + Clone + Send + 'static, + T: Send + 'static, + S::Future: Send, +{ + type Response = (Vec, String, ToolResultContent); + type Error = ToolSetError; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: CompletionRequest) -> Self::Future { + let mut inner = self.inner.clone(); + let tools = self.tools.clone(); + let mut messages = req.chat_history.clone(); + + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + todo!("Handle error properly"); + }; + + let AssistantContent::ToolCall(tool_call) = res.choice.first() else { + todo!("Handle error properly"); + }; + + messages.push(Message::Assistant { + content: OneOrMany::one(AssistantContent::ToolCall(tool_call.clone())), + }); + + let Ok(res) = tools + .call( + &tool_call.function.name, + tool_call.function.arguments.to_string(), + ) + .await + else { + todo!("Implement proper error handling"); + }; + + Ok((messages, tool_call.id, ToolResultContent::text(res))) + }) + } +} + +pub struct ToolService { + tools: Arc, +} + +type OpenAIResponse = CompletionResponse; + +impl Service for ToolService { + type Response = Message; + type Error = ToolSetError; + type Future = Pin> + Send>>; + + fn poll_ready( + &mut self, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: OpenAIResponse) -> Self::Future { + let tools = self.tools.clone(); + + Box::pin(async move { + let crate::message::AssistantContent::ToolCall(tool_call) = req.choice.first() else { + unimplemented!("handle error"); + }; + + let Ok(res) = tools + .call( + &tool_call.function.name, + tool_call.function.arguments.to_string(), + ) + .await + else { + todo!("Implement proper error handling"); + }; + + Ok(Message::User { + content: OneOrMany::one(UserContent::tool_result( + tool_call.id, + OneOrMany::one(ToolResultContent::text(res)), + )), + }) + }) + } +} From c2ec18e45d23793815366c2019712a0ca4d1cf09 Mon Sep 17 00:00:00 2001 From: Joshua Mo Date: Sun, 6 Apr 2025 20:35:20 +0100 Subject: [PATCH 04/11] refactor: all services should use CompletionRequest, add components --- Cargo.toml | 3 +- rig-core/src/middlewares/completion.rs | 83 +------- rig-core/src/middlewares/components.rs | 250 +++++++++++++++++++++++++ rig-core/src/middlewares/extractor.rs | 20 +- rig-core/src/middlewares/mod.rs | 3 +- rig-core/src/middlewares/rag.rs | 18 +- 6 files changed, 279 insertions(+), 98 deletions(-) create mode 100644 rig-core/src/middlewares/components.rs diff --git a/Cargo.toml b/Cargo.toml index a7309db..eb99783 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ members = [ "rig-qdrant", "rig-core/rig-core-derive", "rig-sqlite", - "rig-eternalai", "rig-fastembed", + "rig-eternalai", + "rig-fastembed", "rig-surrealdb", ] diff --git a/rig-core/src/middlewares/completion.rs b/rig-core/src/middlewares/completion.rs index 86aa972..df6714a 100644 --- a/rig-core/src/middlewares/completion.rs +++ b/rig-core/src/middlewares/completion.rs @@ -4,7 +4,6 @@ use std::{ task::{Context, Poll}, }; -use serde::{Deserialize, Serialize}; use tower::{Layer, Service}; use crate::{ @@ -16,12 +15,9 @@ use crate::{ OneOrMany, }; -use super::rag::RagResult; - pub struct CompletionLayer { model: M, preamble: Option, - chat_history: Vec, documents: Vec, tools: Vec, temperature: Option, @@ -42,7 +38,6 @@ where pub struct CompletionLayerBuilder { model: M, preamble: Option, - chat_history: Vec, documents: Vec, tools: Vec, temperature: Option, @@ -58,7 +53,6 @@ where Self { model, preamble: None, - chat_history: vec![], documents: vec![], tools: vec![], temperature: None, @@ -79,12 +73,6 @@ where self } - pub fn chat_history(mut self, chat_history: Vec) -> Self { - self.chat_history = chat_history; - - self - } - pub fn documents(mut self, documents: Vec) -> Self { self.documents = documents; @@ -137,7 +125,7 @@ where CompletionLayer { model: self.model, preamble: self.preamble, - chat_history: self.chat_history, + documents: self.documents, tools: self.tools, temperature: self.temperature, @@ -157,7 +145,7 @@ where inner, model: self.model.clone(), preamble: self.preamble.clone(), - chat_history: self.chat_history.clone(), + documents: self.documents.clone(), tools: self.tools.clone(), temperature: self.temperature, @@ -171,7 +159,6 @@ pub struct CompletionLayerService { inner: S, model: M, preamble: Option, - chat_history: Vec, documents: Vec, tools: Vec, temperature: Option, @@ -179,64 +166,6 @@ pub struct CompletionLayerService { additional_params: Option, } -impl Service for CompletionLayerService -where - M: CompletionModel + 'static, - S: Service> + Clone + Send + 'static, - S::Future: Send, - T: Serialize + for<'a> Deserialize<'a>, -{ - type Response = CompletionResponse; - type Error = CompletionError; - type Future = Pin> + Send>>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: String) -> Self::Future { - let mut inner = self.inner.clone(); - let model = self.model.clone(); - let preamble = self.preamble.clone(); - let chat_history = self.chat_history.clone(); - let documents = self.documents.clone(); - let temperature = self.temperature.clone(); - let tools = self.tools.clone(); - let max_tokens = self.max_tokens.clone(); - let additional_params = self.additional_params.clone(); - - Box::pin(async move { - let Ok(res) = inner.call(req.clone()).await else { - todo!("Handle error properly"); - }; - - let res: String = res - .into_iter() - .filter_map(|x| serde_json::to_string_pretty(&x.2).ok()) - .collect::>() - .join("\n"); - - let req = format!("{req}\n\nContext:\n{res}"); - - let mut req = CompletionRequestBuilder::new(model.clone(), req) - .documents(documents.clone()) - .tools(tools.clone()) - .messages(chat_history.clone()) - .temperature_opt(temperature.clone()) - .max_tokens_opt(max_tokens.clone()) - .additional_params_opt(additional_params.clone()); - - if let Some(preamble) = preamble.clone() { - req = req.preamble(preamble); - } - - let req = req.build(); - - model.completion(req).await - }) - } -} - impl Service for CompletionLayerService where M: CompletionModel + 'static, @@ -259,9 +188,9 @@ where let model = self.model.clone(); let preamble = self.preamble.clone(); let documents = self.documents.clone(); - let temperature = self.temperature.clone(); + let temperature = self.temperature; let tools = self.tools.clone(); - let max_tokens = self.max_tokens.clone(); + let max_tokens = self.max_tokens; let additional_params = self.additional_params.clone(); Box::pin(async move { @@ -277,8 +206,8 @@ where .documents(documents.clone()) .tools(tools.clone()) .messages(messages) - .temperature_opt(temperature.clone()) - .max_tokens_opt(max_tokens.clone()) + .temperature_opt(temperature) + .max_tokens_opt(max_tokens) .additional_params_opt(additional_params.clone()); if let Some(preamble) = preamble.clone() { diff --git a/rig-core/src/middlewares/components.rs b/rig-core/src/middlewares/components.rs new file mode 100644 index 0000000..5b1d5c4 --- /dev/null +++ b/rig-core/src/middlewares/components.rs @@ -0,0 +1,250 @@ +use std::{fmt::Debug, future::Future, marker::PhantomData, pin::Pin, task::Poll}; + +use tower::{Layer, Service}; + +use crate::{ + completion::{CompletionRequest, CompletionResponse}, + message::{AssistantContent, Text}, +}; + +pub struct AwaitApprovalLayer { + integration: T, +} + +impl AwaitApprovalLayer +where + T: HumanInTheLoop + Clone, +{ + pub fn new(integration: T) -> Self { + Self { integration } + } + + pub fn with_predicate(self, predicate: R) -> AwaitApprovalLayerWithPredicate + where + D: Debug, + R: Fn() -> Pin + Send>> + Clone + Send + 'static, + { + AwaitApprovalLayerWithPredicate { + integration: self.integration, + predicate, + _t: PhantomData, + } + } +} + +impl Layer for AwaitApprovalLayer +where + T: HumanInTheLoop + Clone, +{ + type Service = AwaitApprovalLayerService; + + fn layer(&self, inner: S) -> Self::Service { + AwaitApprovalLayerService::new(inner, self.integration.clone()) + } +} + +pub struct AwaitApprovalLayerService { + inner: S, + integration: T, +} + +impl AwaitApprovalLayerService +where + T: HumanInTheLoop, +{ + pub fn new(inner: S, integration: T) -> Self { + Self { inner, integration } + } +} + +impl Service for AwaitApprovalLayerService +where + S: Service> + Clone + Send + 'static, + S::Future: Send, + Response: Clone + 'static + Send, + T: HumanInTheLoop + Clone + Send + 'static, +{ + type Response = CompletionResponse; + type Error = bool; + type Future = Pin> + Send>>; + + fn poll_ready( + &mut self, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: CompletionRequest) -> Self::Future { + let mut inner = self.inner.clone(); + let await_approval_loop = self.integration.clone(); + + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + todo!("Handle error properly"); + }; + + let AssistantContent::Text(Text { text }) = res.choice.first() else { + todo!("Handle error properly"); + }; + + if await_approval_loop.send_message(&text).await.is_err() { + todo!("Handle error properly"); + } + + let Ok(bool) = await_approval_loop.await_approval().await else { + todo!("Handle error properly"); + }; + + if bool { + Ok(res) + } else { + todo!("Handle error properly - we should abort the pipeline here if the user wants to abort"); + } + }) + } +} + +pub struct AwaitApprovalLayerWithPredicate { + integration: T, + predicate: R, + _t: PhantomData

, +} + +impl Layer for AwaitApprovalLayerWithPredicate +where + T: HumanInTheLoop + Clone, + D: Debug, + R: Fn(&D) -> Pin + Send>> + Clone + Send + 'static, +{ + type Service = AwaitApprovalLayerServiceWithPredicate; + + fn layer(&self, inner: S) -> Self::Service { + let predicate = self.predicate.clone(); + AwaitApprovalLayerServiceWithPredicate::new( + inner, + self.integration.clone(), + predicate, + self._t, + ) + } +} + +pub struct AwaitApprovalLayerServiceWithPredicate { + inner: S, + integration: T, + predicate: R, + _t: PhantomData, +} + +impl AwaitApprovalLayerServiceWithPredicate +where + T: HumanInTheLoop, + R: Fn(&D) -> Pin + Send>> + Clone + Send + 'static, + D: Debug, +{ + pub fn new(inner: S, integration: T, predicate: R, _t: PhantomData) -> Self { + Self { + inner, + integration, + predicate, + _t, + } + } +} + +impl Service + for AwaitApprovalLayerServiceWithPredicate +where + R: Fn(&CompletionResponse) -> Pin + Send>> + + Clone + + Send + + 'static, + S: Service> + Clone + Send + 'static, + S::Future: Send, + Response: Clone + 'static + Send, + T: HumanInTheLoop + Clone + Send + 'static, +{ + type Response = CompletionResponse; + type Error = bool; + type Future = Pin> + Send>>; + + fn poll_ready( + &mut self, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: CompletionRequest) -> Self::Future { + let mut inner = self.inner.clone(); + let await_approval_loop = self.integration.clone(); + let predicate = self.predicate.clone(); + + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + todo!("Handle error properly"); + }; + + if predicate(&res).await { + return Ok(res); + } + + let AssistantContent::Text(Text { text }) = res.choice.first() else { + todo!("Handle error properly"); + }; + + if await_approval_loop.send_message(&text).await.is_err() { + todo!("Handle error properly"); + } + + let Ok(bool) = await_approval_loop.await_approval().await else { + todo!("Handle error properly"); + }; + + if bool { + Ok(res) + } else { + todo!("Handle error properly - we should abort the pipeline here if the user wants to abort"); + } + }) + } +} + +pub trait HumanInTheLoop { + fn send_message( + &self, + res: &str, + ) -> impl Future>> + Send; + fn await_approval( + &self, + ) -> impl Future>> + Send; +} + +pub struct Stdout; + +impl HumanInTheLoop for Stdout { + async fn send_message(&self, res: &str) -> Result<(), Box> { + print!( + "Current result: {res} + + Would you like to approve this step? [Y/n]" + ); + + Ok(()) + } + + async fn await_approval(&self) -> Result> { + let mut string = String::new(); + + loop { + std::io::stdin().read_line(&mut string).unwrap(); + + match string.to_lowercase().trim() { + "y" | "yes" => break Ok(true), + "n" | "no" => break Ok(false), + _ => println!("Please respond with 'y' or 'n'."), + } + } + } +} diff --git a/rig-core/src/middlewares/extractor.rs b/rig-core/src/middlewares/extractor.rs index e2eddd0..938e08d 100644 --- a/rig-core/src/middlewares/extractor.rs +++ b/rig-core/src/middlewares/extractor.rs @@ -1,19 +1,13 @@ -use std::{fmt::Display, future::Future, marker::PhantomData, pin::Pin, sync::Arc, task::Poll}; +use std::{future::Future, pin::Pin, sync::Arc, task::Poll}; -use mime_guess::mime::PLAIN; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use tower::{Layer, Service}; use crate::{ - completion::{ - CompletionModel, CompletionRequest, CompletionRequestBuilder, CompletionResponse, - ToolDefinition, - }, + completion::{CompletionModel, CompletionRequest, CompletionResponse}, extractor::{ExtractionError, Extractor}, message::{AssistantContent, Text}, - pipeline::agent_ops::prompt, - tool::Tool, }; pub struct ExtractorLayer @@ -114,11 +108,10 @@ where } } -impl Service

for ExtractorService +impl Service for ExtractorService where M: CompletionModel, T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, - P: Display + Send + 'static, { type Response = T; type Error = ExtractionError; @@ -131,9 +124,12 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, req: P) -> Self::Future { + fn call(&mut self, req: CompletionRequest) -> Self::Future { let ext = self.ext.clone(); + let Some(req) = req.prompt.rag_text() else { + todo!("Handle error properly"); + }; - Box::pin(async move { ext.extract(&req.to_string()).await }) + Box::pin(async move { ext.extract(&req).await }) } } diff --git a/rig-core/src/middlewares/mod.rs b/rig-core/src/middlewares/mod.rs index 534c421..f99abea 100644 --- a/rig-core/src/middlewares/mod.rs +++ b/rig-core/src/middlewares/mod.rs @@ -3,6 +3,7 @@ use thiserror::Error; use crate::{completion::CompletionError, extractor::ExtractionError, tool::ToolSetError}; pub mod completion; +pub mod components; pub mod extractor; pub mod rag; pub mod tools; @@ -15,6 +16,4 @@ pub enum ServiceError { CompletionError(#[from] CompletionError), #[error("{0}")] ToolSetError(#[from] ToolSetError), - #[error("Received incorrect message, must be {0}")] - InvalidMessageType(String), } diff --git a/rig-core/src/middlewares/rag.rs b/rig-core/src/middlewares/rag.rs index 3c11309..ebf3dcd 100644 --- a/rig-core/src/middlewares/rag.rs +++ b/rig-core/src/middlewares/rag.rs @@ -9,7 +9,10 @@ use std::{ use serde::{Deserialize, Serialize}; use tower::Service; -use crate::vector_store::{VectorStoreError, VectorStoreIndex}; +use crate::{ + completion::CompletionRequest, + vector_store::{VectorStoreError, VectorStoreIndex}, +}; pub struct RagService { vector_index: Arc, @@ -25,12 +28,12 @@ where Self { vector_index: Arc::new(vector_index), num_results, - _phantom: PhantomData::default(), + _phantom: PhantomData, } } } -impl Service for RagService +impl Service for RagService where V: VectorStoreIndex + 'static, T: Serialize + for<'a> Deserialize<'a> + Send, @@ -43,11 +46,14 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, req: String) -> Self::Future { + fn call(&mut self, req: CompletionRequest) -> Self::Future { let vector_index = self.vector_index.clone(); - let num_results = self.num_results.clone(); + let num_results = self.num_results; + let Some(prompt) = req.prompt.rag_text() else { + todo!("Handle error properly"); + }; - Box::pin(async move { vector_index.top_n(&req, num_results).await }) + Box::pin(async move { vector_index.top_n(&prompt, num_results).await }) } } From eda2822beeb2a52ce44b2857a1dedb7bf3f5ef54 Mon Sep 17 00:00:00 2001 From: Joshua Mo Date: Mon, 7 Apr 2025 15:48:12 +0100 Subject: [PATCH 05/11] chore: satisfy ci (clippy) --- rig-neo4j/src/lib.rs | 2 +- rig-neo4j/src/vector_index.rs | 2 +- rig-qdrant/src/lib.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/rig-neo4j/src/lib.rs b/rig-neo4j/src/lib.rs index 3868643..607238e 100644 --- a/rig-neo4j/src/lib.rs +++ b/rig-neo4j/src/lib.rs @@ -277,7 +277,7 @@ impl Neo4jClient { /// ### Arguments /// * `index_name` - The name of the index to create. /// * `node_label` - The label of the nodes to which the index will be applied. For example, if your nodes have - /// the label `:Movie`, pass "Movie" as the `node_label` parameter. + /// the label `:Movie`, pass "Movie" as the `node_label` parameter. /// * `embedding_prop_name` (optional) - The name of the property that contains the embedding vectors. Defaults to "embedding". /// pub async fn create_vector_index( diff --git a/rig-neo4j/src/vector_index.rs b/rig-neo4j/src/vector_index.rs index 8b43adf..6565426 100644 --- a/rig-neo4j/src/vector_index.rs +++ b/rig-neo4j/src/vector_index.rs @@ -211,7 +211,7 @@ impl VectorStoreIndex for Neo4jVec /// #### Generic Type Parameters /// /// - `T`: The type used to deserialize the result from the Neo4j query. - /// It must implement the `serde::Deserialize` trait. + /// It must implement the `serde::Deserialize` trait. /// /// #### Returns /// diff --git a/rig-qdrant/src/lib.rs b/rig-qdrant/src/lib.rs index f97bad1..038ec41 100644 --- a/rig-qdrant/src/lib.rs +++ b/rig-qdrant/src/lib.rs @@ -29,7 +29,7 @@ impl QdrantVectorStore { /// * `client` - Qdrant client instance /// * `model` - Embedding model instance /// * `query_params` - Search parameters for vector queries - /// Reference: + /// Reference: pub fn new(client: Qdrant, model: M, query_params: QueryPoints) -> Self { Self { client, From 65943130eca2911eea855ca11d9f16fd180bb270 Mon Sep 17 00:00:00 2001 From: Joshua Mo Date: Mon, 7 Apr 2025 16:19:25 +0100 Subject: [PATCH 06/11] ci: satisfy clippy --- rig-qdrant/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rig-qdrant/src/lib.rs b/rig-qdrant/src/lib.rs index 038ec41..cb21627 100644 --- a/rig-qdrant/src/lib.rs +++ b/rig-qdrant/src/lib.rs @@ -29,7 +29,7 @@ impl QdrantVectorStore { /// * `client` - Qdrant client instance /// * `model` - Embedding model instance /// * `query_params` - Search parameters for vector queries - /// Reference: + /// Reference: pub fn new(client: Qdrant, model: M, query_params: QueryPoints) -> Self { Self { client, From ed0baca25223ab2e3023bc3f272df8f9c9829342 Mon Sep 17 00:00:00 2001 From: Joshua Mo Date: Mon, 7 Apr 2025 16:21:57 +0100 Subject: [PATCH 07/11] chore: satisfy ci --- rig-neo4j/src/lib.rs | 2 +- rig-neo4j/src/vector_index.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rig-neo4j/src/lib.rs b/rig-neo4j/src/lib.rs index 607238e..f817180 100644 --- a/rig-neo4j/src/lib.rs +++ b/rig-neo4j/src/lib.rs @@ -277,7 +277,7 @@ impl Neo4jClient { /// ### Arguments /// * `index_name` - The name of the index to create. /// * `node_label` - The label of the nodes to which the index will be applied. For example, if your nodes have - /// the label `:Movie`, pass "Movie" as the `node_label` parameter. + /// the label `:Movie`, pass "Movie" as the `node_label` parameter. /// * `embedding_prop_name` (optional) - The name of the property that contains the embedding vectors. Defaults to "embedding". /// pub async fn create_vector_index( diff --git a/rig-neo4j/src/vector_index.rs b/rig-neo4j/src/vector_index.rs index 6565426..c2e22d6 100644 --- a/rig-neo4j/src/vector_index.rs +++ b/rig-neo4j/src/vector_index.rs @@ -211,7 +211,7 @@ impl VectorStoreIndex for Neo4jVec /// #### Generic Type Parameters /// /// - `T`: The type used to deserialize the result from the Neo4j query. - /// It must implement the `serde::Deserialize` trait. + /// It must implement the `serde::Deserialize` trait. /// /// #### Returns /// From 49221ab84813d7a1dcc88901e3e5fe5c5e9bb957 Mon Sep 17 00:00:00 2001 From: Joshua Mo Date: Wed, 9 Apr 2025 17:15:24 +0100 Subject: [PATCH 08/11] feat: a bunch of stuff --- rig-core/examples/tower.rs | 67 +++ rig-core/src/completion/request.rs | 1 + rig-core/src/middlewares/build_completions.rs | 532 ++++++++++++++++++ rig-core/src/middlewares/completion.rs | 1 + rig-core/src/middlewares/extractor.rs | 111 +--- rig-core/src/middlewares/mod.rs | 20 +- rig-core/src/middlewares/parallel.rs | 88 +++ rig-core/src/middlewares/rag.rs | 22 +- rig-core/src/middlewares/tools.rs | 49 +- 9 files changed, 748 insertions(+), 143 deletions(-) create mode 100644 rig-core/examples/tower.rs create mode 100644 rig-core/src/middlewares/build_completions.rs create mode 100644 rig-core/src/middlewares/parallel.rs diff --git a/rig-core/examples/tower.rs b/rig-core/examples/tower.rs new file mode 100644 index 0000000..0949b11 --- /dev/null +++ b/rig-core/examples/tower.rs @@ -0,0 +1,67 @@ +use rig::{ + completion::CompletionRequestBuilder, + middlewares::{ + completion::{CompletionLayer, CompletionService}, + tools::ToolLayer, + }, + providers::openai::Client, +}; +use tower::ServiceBuilder; + +#[tokio::main] +async fn main() { + let client = Client::from_env(); + let model = client.completion_model("gpt-4o"); + + let comp_layer = CompletionLayer::builder(model).build(); + let tool_layer = ToolLayer::new(vec![Add]); + let service = CompletionService::new(model); + + let service = ServiceBuilder::new() + .layer(comp_layer) + .layer(tool_layer) + .service(service); + + let comp_request = CompletionRequestBuilder::new(model, "Please calculate 5+5 for me").build(); + + let res = service.call(comp_request).await.unwrap(); + + println!("{res:?}"); +} + +#[derive(Deserialize, Serialize)] +struct Add; + +impl Tool for Add { + const NAME: &'static str = "add"; + + type Error = MathError; + type Args = OperationArgs; + type Output = i32; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + serde_json::from_value(json!({ + "name": "add", + "description": "Add x and y together", + "parameters": { + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "The first number to add" + }, + "y": { + "type": "number", + "description": "The second number to add" + } + } + } + })) + .expect("Tool Definition") + } + + async fn call(&self, args: Self::Args) -> Result { + let result = args.x + args.y; + Ok(result) + } +} diff --git a/rig-core/src/completion/request.rs b/rig-core/src/completion/request.rs index 9a31fae..60456b0 100644 --- a/rig-core/src/completion/request.rs +++ b/rig-core/src/completion/request.rs @@ -235,6 +235,7 @@ pub trait CompletionModel: Clone + Send + Sync { } /// Struct representing a general completion request that can be sent to a completion model provider. +#[derive(Clone)] pub struct CompletionRequest { /// The prompt to be sent to the completion model provider pub prompt: Message, diff --git a/rig-core/src/middlewares/build_completions.rs b/rig-core/src/middlewares/build_completions.rs new file mode 100644 index 0000000..f2d626d --- /dev/null +++ b/rig-core/src/middlewares/build_completions.rs @@ -0,0 +1,532 @@ +use std::{future::Future, pin::Pin, task::Poll}; +use tower::{Layer, Service}; + +use crate::{ + completion::{ + CompletionModel, CompletionRequest, CompletionRequestBuilder, Document, ToolDefinition, + }, + message::Message, +}; + +/// A Tower layer to finish building your `CompletionRequestBuilder`. +/// Intended to be used with [`CompletionRequestBuilderService`]. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct FinishBuilding; + +impl Layer for FinishBuilding { + type Service = FinishBuildingService; + + fn layer(&self, inner: S) -> Self::Service { + FinishBuildingService { inner } + } +} + +/// A Tower layer to finish building your `CompletionRequestBuilder`. +/// Not intended to be used directly. Use [`FinishBuilding`] instead. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct FinishBuildingService { + inner: S, +} + +impl Service<(M, Msg)> for FinishBuildingService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, + S: Service<(M, Msg), Response = CompletionRequestBuilder> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = CompletionRequest; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: (M, Msg)) -> Self::Future { + let mut inner = self.inner.clone(); + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + todo!("Handle error properly"); + }; + + Ok(res.build()) + }) + } +} + +/// A Tower layer to add documents to your `CompletionRequestBuilder`. +/// Intended to be used with [`CompletionRequestBuilderService`]. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct DocumentsLayer { + documents: Vec, +} + +impl DocumentsLayer { + pub fn new(documents: Vec) -> Self { + Self { documents } + } +} + +impl Layer for DocumentsLayer { + type Service = DocumentsLayerService; + + fn layer(&self, inner: S) -> Self::Service { + DocumentsLayerService { + inner, + documents: self.documents.clone(), + } + } +} + +/// A Tower service to add documents to your `CompletionRequestBuilder`. +/// Not intended to be used directly - use [`DocumentsLayer`] instead. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct DocumentsLayerService { + inner: S, + documents: Vec, +} + +impl Service<(M, Msg)> for DocumentsLayerService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, + S: Service<(M, Msg), Response = CompletionRequestBuilder> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = CompletionRequestBuilder; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: (M, Msg)) -> Self::Future { + let mut inner = self.inner.clone(); + let documents = self.documents.clone(); + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + todo!("Handle error properly"); + }; + + Ok(res.documents(documents)) + }) + } +} + +/// A Tower layer to add a temperature value to your `CompletionRequestBuilder`. +/// Intended to be used with [`CompletionRequestBuilderService`]. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct TemperatureLayer { + temperature: f64, +} + +impl TemperatureLayer { + pub fn new(temperature: f64) -> Self { + Self { temperature } + } +} + +impl Layer for TemperatureLayer { + type Service = TemperatureLayerService; + + fn layer(&self, inner: S) -> Self::Service { + TemperatureLayerService { + inner, + temperature: self.temperature, + } + } +} + +/// A Tower service to add a temperature value to your `CompletionRequestBuilder`. +/// Not intended to be used directly - use [`TemperatureLayer`] instead. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct TemperatureLayerService { + inner: S, + temperature: f64, +} + +impl Service<(M, Msg)> for TemperatureLayerService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, + S: Service<(M, Msg), Response = CompletionRequestBuilder> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = CompletionRequestBuilder; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: (M, Msg)) -> Self::Future { + let mut inner = self.inner.clone(); + let temperature = self.temperature; + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + todo!("Handle error properly"); + }; + + Ok(res.temperature(temperature)) + }) + } +} + +/// A Tower service to add tools to your `CompletionRequestBuilder`. +/// Intended to be used with [`CompletionRequestBuilderService`]. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct ToolsLayer { + tools: Vec, +} + +impl ToolsLayer { + pub fn new(tools: Vec) -> Self { + Self { tools } + } +} + +impl Layer for ToolsLayer { + type Service = ToolsLayerService; + + fn layer(&self, inner: S) -> Self::Service { + ToolsLayerService { + inner, + tools: self.tools.clone(), + } + } +} + +/// A Tower service to add tools to your `CompletionRequestBuilder`. +/// Not intended to be used directly - use [`ToolsLayer`] instead. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct ToolsLayerService { + inner: S, + tools: Vec, +} + +impl Service<(M, Msg)> for ToolsLayerService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, + S: Service<(M, Msg), Response = CompletionRequestBuilder> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = CompletionRequestBuilder; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: (M, Msg)) -> Self::Future { + let mut inner = self.inner.clone(); + let tools = self.tools.clone(); + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + todo!("Handle error properly"); + }; + + Ok(res.tools(tools)) + }) + } +} + +/// A Tower layer to add a preamble ("system message") to your `CompletionRequestBuilder`. +/// Intended to be used with [`CompletionRequestBuilderService`]. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct PreambleLayer { + preamble: String, +} + +impl PreambleLayer { + pub fn new(preamble: String) -> Self { + Self { preamble } + } +} + +impl Layer for PreambleLayer { + type Service = PreambleLayerService; + + fn layer(&self, inner: S) -> Self::Service { + PreambleLayerService { + inner, + preamble: self.preamble.clone(), + } + } +} + +/// A Tower service to add a preamble ("system message") to your `CompletionRequestBuilder`. +/// Not intended to be used directly - use [`PreambleLayer`] instead. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct PreambleLayerService { + inner: S, + preamble: String, +} + +impl Service<(M, Msg)> for PreambleLayerService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, + S: Service<(M, Msg), Response = CompletionRequestBuilder> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = CompletionRequestBuilder; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: (M, Msg)) -> Self::Future { + let mut inner = self.inner.clone(); + let preamble = self.preamble.clone(); + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + todo!("Handle error properly"); + }; + + Ok(res.preamble(preamble)) + }) + } +} + +/// A Tower layer to add additional parameters to your `CompletionRequestBuilder` (which are not already covered by the other parameters). +/// Intended to be used with [`CompletionRequestBuilderService`]. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct AdditionalParamsLayer { + additional_params: serde_json::Value, +} + +impl AdditionalParamsLayer { + pub fn new(additional_params: serde_json::Value) -> Self { + Self { additional_params } + } +} + +impl Layer for AdditionalParamsLayer { + type Service = AdditionalParamsLayerService; + + fn layer(&self, inner: S) -> Self::Service { + AdditionalParamsLayerService { + inner, + additional_params: self.additional_params.clone(), + } + } +} + +/// A Tower layer to add additional parameters to your `CompletionRequestBuilder` (which are not already covered by the other parameters). +/// Not intended to be used directly - use [`AdditionalParamsLayer`] instead. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct AdditionalParamsLayerService { + inner: S, + additional_params: serde_json::Value, +} + +impl Service<(M, Msg)> for AdditionalParamsLayerService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, + S: Service<(M, Msg), Response = CompletionRequestBuilder> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = CompletionRequestBuilder; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: (M, Msg)) -> Self::Future { + let mut inner = self.inner.clone(); + let additional_params = self.additional_params.clone(); + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + todo!("Handle error properly"); + }; + + Ok(res.additional_params(additional_params)) + }) + } +} + +/// A Tower layer to add a maximum tokens parameter to your `CompletionRequestBuilder`. +/// Intended to be used with [`CompletionRequestBuilderService`]. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct MaxTokensLayer { + max_tokens: u64, +} + +impl MaxTokensLayer { + pub fn new(max_tokens: u64) -> Self { + Self { max_tokens } + } +} + +impl Layer for MaxTokensLayer { + type Service = MaxTokensLayerService; + + fn layer(&self, inner: S) -> Self::Service { + MaxTokensLayerService { + inner, + max_tokens: self.max_tokens, + } + } +} + +/// A Tower service to add a maximum tokens parameter to your `CompletionRequestBuilder`. +/// Not intended to be used directly - use [`MaxTokensLayer`] instead. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct MaxTokensLayerService { + inner: S, + max_tokens: u64, +} + +impl Service<(M, Msg)> for MaxTokensLayerService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, + S: Service<(M, Msg), Response = CompletionRequestBuilder> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = CompletionRequestBuilder; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: (M, Msg)) -> Self::Future { + let mut inner = self.inner.clone(); + let max_tokens = self.max_tokens; + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + todo!("Handle error properly"); + }; + + Ok(res.max_tokens(max_tokens)) + }) + } +} + +/// A Tower layer to add a chat history to your `CompletionRequestBuilder`. +/// Intended to be used with [`CompletionRequestBuilderService`]. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct ChatHistoryLayer { + chat_history: Vec, +} + +impl ChatHistoryLayer { + pub fn new(chat_history: Vec) -> Self { + Self { chat_history } + } +} + +impl Layer for ChatHistoryLayer { + type Service = ChatHistoryLayerService; + + fn layer(&self, inner: S) -> Self::Service { + ChatHistoryLayerService { + inner, + chat_history: self.chat_history.clone(), + } + } +} + +/// A Tower service to add a chat history to your `CompletionRequestBuilder`. +/// Not intended to be used directly - use [`ChatHistoryLayer`] instead. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct ChatHistoryLayerService { + inner: S, + chat_history: Vec, +} + +impl Service<(M, Msg)> for ChatHistoryLayerService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, + S: Service<(M, Msg), Response = CompletionRequestBuilder> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = CompletionRequestBuilder; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: (M, Msg)) -> Self::Future { + let mut inner = self.inner.clone(); + let chat_history = self.chat_history.clone(); + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + todo!("Handle error properly"); + }; + + Ok(res.messages(chat_history)) + }) + } +} + +/// A `tower` service for building a completion request. +/// Note that the last layer to be applied goes first and the core service is placed last (so it gets executedd from bottom to top). +/// +/// Usage: +/// ```rust +/// let agent = rig::providers::openai::Client::from_env(); +/// let model = agent.completion_model("gpt-4o"); +/// +/// let service = tower::ServiceBuilder::new() +/// .layer(FinishBuilding) +/// .layer(TemperatureLayer::new(0.0)) +/// .layer(PreambleLayer::new("You are a helpful assistant")) +/// .service(CompletionRequestBuilderService); +/// +/// let request = service.call((model, "Hello world!".to_string())).await.unwrap(); +/// +/// let res = request.send().await.unwrap(); +/// +/// println!("{res:?}"); +/// ``` +pub struct CompletionRequestBuilderService; + +impl Service<(M, Msg)> for CompletionRequestBuilderService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, +{ + type Response = CompletionRequestBuilder; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, (model, prompt): (M, Msg)) -> Self::Future { + Box::pin(async move { Ok(CompletionRequestBuilder::new(model, prompt)) }) + } +} diff --git a/rig-core/src/middlewares/completion.rs b/rig-core/src/middlewares/completion.rs index df6714a..1b216f2 100644 --- a/rig-core/src/middlewares/completion.rs +++ b/rig-core/src/middlewares/completion.rs @@ -224,6 +224,7 @@ where /// A completion model as a Tower service. /// /// This allows you to use an LLM model (or client) essentially anywhere you'd use a regular Tower layer, like in an Axum web service. +#[derive(Clone)] pub struct CompletionService { /// The model itself. model: M, diff --git a/rig-core/src/middlewares/extractor.rs b/rig-core/src/middlewares/extractor.rs index 938e08d..528828c 100644 --- a/rig-core/src/middlewares/extractor.rs +++ b/rig-core/src/middlewares/extractor.rs @@ -1,68 +1,54 @@ -use std::{future::Future, pin::Pin, sync::Arc, task::Poll}; - -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; +use std::{future::Future, marker::PhantomData, pin::Pin, task::Poll}; use tower::{Layer, Service}; use crate::{ - completion::{CompletionModel, CompletionRequest, CompletionResponse}, - extractor::{ExtractionError, Extractor}, + completion::{CompletionRequest, CompletionResponse}, message::{AssistantContent, Text}, }; -pub struct ExtractorLayer -where - M: CompletionModel, - T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, -{ - ext: Arc>, +use super::ServiceError; + +pub struct ExtractorLayer { + _t: PhantomData, } -impl ExtractorLayer +impl ExtractorLayer where - M: CompletionModel, - T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, + T: for<'a> Deserialize<'a>, { - pub fn new(ext: Extractor) -> Self { - Self { ext: Arc::new(ext) } + pub fn new() -> Self { + Self { _t: PhantomData } } } -impl Layer for ExtractorLayer +impl Layer for ExtractorLayer where - M: CompletionModel + 'static, - T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static, + T: for<'a> Deserialize<'a>, { - type Service = ExtractorLayerService; + type Service = ExtractorLayerService; fn layer(&self, inner: S) -> Self::Service { - ExtractorLayerService { - inner, - ext: Arc::clone(&self.ext), - } + ExtractorLayerService { inner, _t: self._t } } } -pub struct ExtractorLayerService -where - M: CompletionModel + 'static, - T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static, -{ +pub struct ExtractorLayerService { inner: S, - ext: Arc>, + _t: PhantomData, } -impl Service for ExtractorLayerService +impl Service for ExtractorLayerService where - S: Service> + S: Service, Error = ServiceError> + Clone + Send + 'static, S::Future: Send, - M: CompletionModel + 'static, - T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync, + F: 'static, + T: for<'a> Deserialize<'a> + 'static, { type Response = T; - type Error = ExtractionError; + type Error = ServiceError; type Future = Pin> + Send>>; fn poll_ready( @@ -73,63 +59,18 @@ where } fn call(&mut self, req: CompletionRequest) -> Self::Future { - let ext = self.ext.clone(); let mut inner = self.inner.clone(); Box::pin(async move { - let Ok(res) = inner.call(req).await else { - todo!("Properly handle error"); - }; + let res = inner.call(req).await?; let AssistantContent::Text(Text { text }) = res.choice.first() else { todo!("Handle errors properly"); }; - ext.extract(&text).await + let obj = serde_json::from_str::(&text)?; + + Ok(obj) }) } } - -pub struct ExtractorService -where - M: CompletionModel + 'static, - T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static, -{ - ext: Arc>, -} - -impl ExtractorService -where - M: CompletionModel + 'static, - T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static, -{ - pub fn new(ext: Extractor) -> Self { - Self { ext: Arc::new(ext) } - } -} - -impl Service for ExtractorService -where - M: CompletionModel, - T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, -{ - type Response = T; - type Error = ExtractionError; - type Future = Pin> + Send>>; - - fn poll_ready( - &mut self, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: CompletionRequest) -> Self::Future { - let ext = self.ext.clone(); - let Some(req) = req.prompt.rag_text() else { - todo!("Handle error properly"); - }; - - Box::pin(async move { ext.extract(&req).await }) - } -} diff --git a/rig-core/src/middlewares/mod.rs b/rig-core/src/middlewares/mod.rs index f99abea..2e02d6f 100644 --- a/rig-core/src/middlewares/mod.rs +++ b/rig-core/src/middlewares/mod.rs @@ -1,10 +1,15 @@ use thiserror::Error; -use crate::{completion::CompletionError, extractor::ExtractionError, tool::ToolSetError}; +use crate::{ + completion::CompletionError, extractor::ExtractionError, tool::ToolSetError, + vector_store::VectorStoreError, +}; +pub mod build_completions; pub mod completion; pub mod components; pub mod extractor; +pub mod parallel; pub mod rag; pub mod tools; @@ -16,4 +21,17 @@ pub enum ServiceError { CompletionError(#[from] CompletionError), #[error("{0}")] ToolSetError(#[from] ToolSetError), + #[error("{0}")] + VectorStoreError(#[from] VectorStoreError), + #[error("Value required but was null: {0}")] + RequiredOptionNotFound(String), + #[error("{0}")] + Json(#[from] serde_json::Error), +} + +impl ServiceError { + pub fn required_option_not_exists>(val: S) -> Self { + let val: String = val.into(); + Self::RequiredOptionNotFound(val) + } } diff --git a/rig-core/src/middlewares/parallel.rs b/rig-core/src/middlewares/parallel.rs new file mode 100644 index 0000000..1d619be --- /dev/null +++ b/rig-core/src/middlewares/parallel.rs @@ -0,0 +1,88 @@ +use std::{future::Future, pin::Pin}; + +use tower::Service; + +use crate::completion::CompletionRequest; + +use super::ServiceError; + +pub struct Stackable { + pub inner: A, + pub outer: B, +} + +impl Stackable { + pub fn new(inner: A, outer: B) -> Self { + Self { inner, outer } + } + + pub fn take_values(self) -> (A, B) { + (self.inner, self.outer) + } +} + +#[derive(Clone)] +pub struct ParallelService { + first_service: S, + second_service: T, +} + +impl ParallelService +where + S: Service, + T: Service, +{ + pub fn new(first_service: S, second_service: T) -> Self { + Self { + first_service, + second_service, + } + } +} + +impl Service for ParallelService +where + S: Service + Clone + Send + 'static, + S::Future: Send, + S::Response: Send + 'static, + T: Service + Clone + Send + 'static, + T::Future: Send, + T::Response: Send + 'static, +{ + type Response = Stackable; + type Error = ServiceError; + type Future = Pin> + Send>>; + + fn poll_ready( + &mut self, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + + fn call(&mut self, req: CompletionRequest) -> Self::Future { + let mut first = self.first_service.clone(); + let mut second = self.second_service.clone(); + Box::pin(async move { + let res1 = first.call(req.clone()).await?; + let res2 = second.call(req.clone()).await?; + + let stackable = Stackable::new(res1, res2); + + Ok(stackable) + }) + } +} + +#[macro_export] +macro_rules! parallel_service { + ($service1:tt, $service2:tt) => { + $crate::pipeline::parallel::ParallelService::new($service1, $service2) + }; + ($op1:tt $(, $ops:tt)*) => { + $crate::pipeline::parallel::ParallelService::new( + $service1, + $crate::parallel_op!($($ops),*) + ) + }; +} diff --git a/rig-core/src/middlewares/rag.rs b/rig-core/src/middlewares/rag.rs index ebf3dcd..5af4f3f 100644 --- a/rig-core/src/middlewares/rag.rs +++ b/rig-core/src/middlewares/rag.rs @@ -9,10 +9,9 @@ use std::{ use serde::{Deserialize, Serialize}; use tower::Service; -use crate::{ - completion::CompletionRequest, - vector_store::{VectorStoreError, VectorStoreIndex}, -}; +use crate::{completion::CompletionRequest, vector_store::VectorStoreIndex}; + +use super::ServiceError; pub struct RagService { vector_index: Arc, @@ -39,7 +38,7 @@ where T: Serialize + for<'a> Deserialize<'a> + Send, { type Response = RagResult; - type Error = VectorStoreError; + type Error = ServiceError; type Future = Pin> + Send>>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { @@ -49,11 +48,16 @@ where fn call(&mut self, req: CompletionRequest) -> Self::Future { let vector_index = self.vector_index.clone(); let num_results = self.num_results; - let Some(prompt) = req.prompt.rag_text() else { - todo!("Handle error properly"); - }; - Box::pin(async move { vector_index.top_n(&prompt, num_results).await }) + Box::pin(async move { + let Some(prompt) = req.prompt.rag_text() else { + return Err(ServiceError::required_option_not_exists("rag_text")); + }; + + let res = vector_index.top_n(&prompt, num_results).await?; + + Ok(res) + }) } } diff --git a/rig-core/src/middlewares/tools.rs b/rig-core/src/middlewares/tools.rs index 5a5fc82..e0ab19a 100644 --- a/rig-core/src/middlewares/tools.rs +++ b/rig-core/src/middlewares/tools.rs @@ -1,7 +1,6 @@ use crate::{ completion::{CompletionRequest, CompletionResponse}, - message::{AssistantContent, Message, ToolResultContent, UserContent}, - providers::{self}, + message::{AssistantContent, Message, ToolResultContent}, tool::{ToolSet, ToolSetError}, OneOrMany, }; @@ -83,49 +82,3 @@ where }) } } - -pub struct ToolService { - tools: Arc, -} - -type OpenAIResponse = CompletionResponse; - -impl Service for ToolService { - type Response = Message; - type Error = ToolSetError; - type Future = Pin> + Send>>; - - fn poll_ready( - &mut self, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: OpenAIResponse) -> Self::Future { - let tools = self.tools.clone(); - - Box::pin(async move { - let crate::message::AssistantContent::ToolCall(tool_call) = req.choice.first() else { - unimplemented!("handle error"); - }; - - let Ok(res) = tools - .call( - &tool_call.function.name, - tool_call.function.arguments.to_string(), - ) - .await - else { - todo!("Implement proper error handling"); - }; - - Ok(Message::User { - content: OneOrMany::one(UserContent::tool_result( - tool_call.id, - OneOrMany::one(ToolResultContent::text(res)), - )), - }) - }) - } -} From 0ba10f0e3e91b9963f05a0afa20165827d4aeffd Mon Sep 17 00:00:00 2001 From: Joshua Mo Date: Wed, 9 Apr 2025 22:30:23 +0100 Subject: [PATCH 09/11] chore: satisfy ci --- rig-core/src/middlewares/extractor.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/rig-core/src/middlewares/extractor.rs b/rig-core/src/middlewares/extractor.rs index 528828c..6d3de9d 100644 --- a/rig-core/src/middlewares/extractor.rs +++ b/rig-core/src/middlewares/extractor.rs @@ -22,6 +22,15 @@ where } } +impl Default for ExtractorLayer +where + T: for<'a> Deserialize<'a>, +{ + fn default() -> Self { + Self::new() + } +} + impl Layer for ExtractorLayer where T: for<'a> Deserialize<'a>, From bf652f47e88fd70da5ac614e47921bed6ad0297a Mon Sep 17 00:00:00 2001 From: Joshua Mo Date: Wed, 9 Apr 2025 23:21:43 +0100 Subject: [PATCH 10/11] fix: satisfy ci --- rig-core/examples/tower.rs | 25 +++++++++++++++++++------ rig-core/src/middlewares/completion.rs | 2 ++ rig-core/src/middlewares/components.rs | 2 ++ rig-core/src/middlewares/extractor.rs | 2 ++ rig-core/src/middlewares/tools.rs | 2 ++ 5 files changed, 27 insertions(+), 6 deletions(-) diff --git a/rig-core/examples/tower.rs b/rig-core/examples/tower.rs index 0949b11..f1643d1 100644 --- a/rig-core/examples/tower.rs +++ b/rig-core/examples/tower.rs @@ -1,11 +1,14 @@ use rig::{ - completion::CompletionRequestBuilder, + completion::{CompletionRequestBuilder, ToolDefinition}, middlewares::{ completion::{CompletionLayer, CompletionService}, tools::ToolLayer, }, providers::openai::Client, + tool::{Tool, ToolSet}, }; +use serde::{Deserialize, Serialize}; +use tower::Service; use tower::ServiceBuilder; #[tokio::main] @@ -13,11 +16,11 @@ async fn main() { let client = Client::from_env(); let model = client.completion_model("gpt-4o"); - let comp_layer = CompletionLayer::builder(model).build(); - let tool_layer = ToolLayer::new(vec![Add]); - let service = CompletionService::new(model); + let comp_layer = CompletionLayer::builder(model.clone()).build(); + let tool_layer = ToolLayer::new(ToolSet::from_tools(vec![Add])); + let service = CompletionService::new(model.clone()); - let service = ServiceBuilder::new() + let mut service = ServiceBuilder::new() .layer(comp_layer) .layer(tool_layer) .service(service); @@ -32,6 +35,16 @@ async fn main() { #[derive(Deserialize, Serialize)] struct Add; +#[derive(Debug, thiserror::Error)] +#[error("Math error")] +struct MathError; + +#[derive(Deserialize)] +struct OperationArgs { + x: i32, + y: i32, +} + impl Tool for Add { const NAME: &'static str = "add"; @@ -40,7 +53,7 @@ impl Tool for Add { type Output = i32; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ + serde_json::from_value(serde_json::json!({ "name": "add", "description": "Add x and y together", "parameters": { diff --git a/rig-core/src/middlewares/completion.rs b/rig-core/src/middlewares/completion.rs index 1b216f2..5cfe419 100644 --- a/rig-core/src/middlewares/completion.rs +++ b/rig-core/src/middlewares/completion.rs @@ -15,6 +15,7 @@ use crate::{ OneOrMany, }; +#[derive(Clone)] pub struct CompletionLayer { model: M, preamble: Option, @@ -155,6 +156,7 @@ where } } +#[derive(Clone)] pub struct CompletionLayerService { inner: S, model: M, diff --git a/rig-core/src/middlewares/components.rs b/rig-core/src/middlewares/components.rs index 5b1d5c4..e51bf20 100644 --- a/rig-core/src/middlewares/components.rs +++ b/rig-core/src/middlewares/components.rs @@ -7,6 +7,7 @@ use crate::{ message::{AssistantContent, Text}, }; +#[derive(Clone)] pub struct AwaitApprovalLayer { integration: T, } @@ -43,6 +44,7 @@ where } } +#[derive(Clone)] pub struct AwaitApprovalLayerService { inner: S, integration: T, diff --git a/rig-core/src/middlewares/extractor.rs b/rig-core/src/middlewares/extractor.rs index 6d3de9d..ddd0a29 100644 --- a/rig-core/src/middlewares/extractor.rs +++ b/rig-core/src/middlewares/extractor.rs @@ -9,6 +9,7 @@ use crate::{ use super::ServiceError; +#[derive(Clone)] pub struct ExtractorLayer { _t: PhantomData, } @@ -41,6 +42,7 @@ where } } +#[derive(Clone)] pub struct ExtractorLayerService { inner: S, _t: PhantomData, diff --git a/rig-core/src/middlewares/tools.rs b/rig-core/src/middlewares/tools.rs index e0ab19a..09c9404 100644 --- a/rig-core/src/middlewares/tools.rs +++ b/rig-core/src/middlewares/tools.rs @@ -8,6 +8,7 @@ use std::{future::Future, pin::Pin, sync::Arc, task::Poll}; use tower::{Layer, Service}; +#[derive(Clone)] pub struct ToolLayer { tools: Arc, } @@ -31,6 +32,7 @@ impl Layer for ToolLayer { } } +#[derive(Clone)] pub struct ToolLayerService { inner: S, tools: Arc, From a42ea437630e082f56b75fdf34ee92ef972f3c6c Mon Sep 17 00:00:00 2001 From: Joshua Mo Date: Mon, 14 Apr 2025 16:36:58 +0100 Subject: [PATCH 11/11] chore: amendments --- rig-core/src/middlewares/components.rs | 11 +++-------- rig-core/src/middlewares/mod.rs | 2 ++ 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/rig-core/src/middlewares/components.rs b/rig-core/src/middlewares/components.rs index e51bf20..c8fb3cb 100644 --- a/rig-core/src/middlewares/components.rs +++ b/rig-core/src/middlewares/components.rs @@ -123,12 +123,7 @@ where fn layer(&self, inner: S) -> Self::Service { let predicate = self.predicate.clone(); - AwaitApprovalLayerServiceWithPredicate::new( - inner, - self.integration.clone(), - predicate, - self._t, - ) + AwaitApprovalLayerServiceWithPredicate::new(inner, self.integration.clone(), predicate) } } @@ -145,12 +140,12 @@ where R: Fn(&D) -> Pin + Send>> + Clone + Send + 'static, D: Debug, { - pub fn new(inner: S, integration: T, predicate: R, _t: PhantomData) -> Self { + pub fn new(inner: S, integration: T, predicate: R) -> Self { Self { inner, integration, predicate, - _t, + _t: PhantomData, } } } diff --git a/rig-core/src/middlewares/mod.rs b/rig-core/src/middlewares/mod.rs index 2e02d6f..53da6ab 100644 --- a/rig-core/src/middlewares/mod.rs +++ b/rig-core/src/middlewares/mod.rs @@ -27,6 +27,8 @@ pub enum ServiceError { RequiredOptionNotFound(String), #[error("{0}")] Json(#[from] serde_json::Error), + #[error("Custom error: {0}")] + Other(#[from] Box), } impl ServiceError {