diff --git a/Cargo.lock b/Cargo.lock index 7c21d09..c4be405 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8612,6 +8612,7 @@ dependencies = [ "thiserror 1.0.69", "tokio", "tokio-test", + "tower 0.5.2", "tracing", "tracing-subscriber", "worker", diff --git a/Cargo.toml b/Cargo.toml index a7309db..eb99783 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ members = [ "rig-qdrant", "rig-core/rig-core-derive", "rig-sqlite", - "rig-eternalai", "rig-fastembed", + "rig-eternalai", + "rig-fastembed", "rig-surrealdb", ] diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index cedf63c..d31ac99 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -39,6 +39,7 @@ bytes = "1.9.0" async-stream = "0.3.6" mime_guess = { version = "2.0.5" } base64 = { version = "0.22.1" } +tower = "0.5.2" [dev-dependencies] diff --git a/rig-core/examples/tower.rs b/rig-core/examples/tower.rs new file mode 100644 index 0000000..f1643d1 --- /dev/null +++ b/rig-core/examples/tower.rs @@ -0,0 +1,80 @@ +use rig::{ + completion::{CompletionRequestBuilder, ToolDefinition}, + middlewares::{ + completion::{CompletionLayer, CompletionService}, + tools::ToolLayer, + }, + providers::openai::Client, + tool::{Tool, ToolSet}, +}; +use serde::{Deserialize, Serialize}; +use tower::Service; +use tower::ServiceBuilder; + +#[tokio::main] +async fn main() { + let client = Client::from_env(); + let model = client.completion_model("gpt-4o"); + + let comp_layer = CompletionLayer::builder(model.clone()).build(); + let tool_layer = ToolLayer::new(ToolSet::from_tools(vec![Add])); + let service = CompletionService::new(model.clone()); + + let mut service = ServiceBuilder::new() + .layer(comp_layer) + .layer(tool_layer) + .service(service); + + let comp_request = CompletionRequestBuilder::new(model, "Please calculate 5+5 for me").build(); + + let res = service.call(comp_request).await.unwrap(); + + println!("{res:?}"); +} + +#[derive(Deserialize, Serialize)] +struct Add; + +#[derive(Debug, thiserror::Error)] 
+#[error("Math error")] +struct MathError; + +#[derive(Deserialize)] +struct OperationArgs { + x: i32, + y: i32, +} + +impl Tool for Add { + const NAME: &'static str = "add"; + + type Error = MathError; + type Args = OperationArgs; + type Output = i32; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + serde_json::from_value(serde_json::json!({ + "name": "add", + "description": "Add x and y together", + "parameters": { + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "The first number to add" + }, + "y": { + "type": "number", + "description": "The second number to add" + } + } + } + })) + .expect("Tool Definition") + } + + async fn call(&self, args: Self::Args) -> Result { + let result = args.x + args.y; + Ok(result) + } +} diff --git a/rig-core/src/completion/request.rs b/rig-core/src/completion/request.rs index 9a31fae..60456b0 100644 --- a/rig-core/src/completion/request.rs +++ b/rig-core/src/completion/request.rs @@ -235,6 +235,7 @@ pub trait CompletionModel: Clone + Send + Sync { } /// Struct representing a general completion request that can be sent to a completion model provider. 
+#[derive(Clone)] pub struct CompletionRequest { /// The prompt to be sent to the completion model provider pub prompt: Message, diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 300c962..a13d184 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -91,6 +91,7 @@ pub mod extractor; pub mod image_generation; pub(crate) mod json_utils; pub mod loaders; +pub mod middlewares; pub mod one_or_many; pub mod pipeline; pub mod providers; diff --git a/rig-core/src/middlewares/build_completions.rs b/rig-core/src/middlewares/build_completions.rs new file mode 100644 index 0000000..f2d626d --- /dev/null +++ b/rig-core/src/middlewares/build_completions.rs @@ -0,0 +1,532 @@ +use std::{future::Future, pin::Pin, task::Poll}; +use tower::{Layer, Service}; + +use crate::{ + completion::{ + CompletionModel, CompletionRequest, CompletionRequestBuilder, Document, ToolDefinition, + }, + message::Message, +}; + +/// A Tower layer to finish building your `CompletionRequestBuilder`. +/// Intended to be used with [`CompletionRequestBuilderService`]. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct FinishBuilding; + +impl Layer for FinishBuilding { + type Service = FinishBuildingService; + + fn layer(&self, inner: S) -> Self::Service { + FinishBuildingService { inner } + } +} + +/// A Tower layer to finish building your `CompletionRequestBuilder`. +/// Not intended to be used directly. Use [`FinishBuilding`] instead. +/// +/// See [`CompletionRequestBuilderService`] for usage. 
+pub struct FinishBuildingService { + inner: S, +} + +impl Service<(M, Msg)> for FinishBuildingService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, + S: Service<(M, Msg), Response = CompletionRequestBuilder> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = CompletionRequest; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: (M, Msg)) -> Self::Future { + let mut inner = self.inner.clone(); + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + todo!("Handle error properly"); + }; + + Ok(res.build()) + }) + } +} + +/// A Tower layer to add documents to your `CompletionRequestBuilder`. +/// Intended to be used with [`CompletionRequestBuilderService`]. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct DocumentsLayer { + documents: Vec, +} + +impl DocumentsLayer { + pub fn new(documents: Vec) -> Self { + Self { documents } + } +} + +impl Layer for DocumentsLayer { + type Service = DocumentsLayerService; + + fn layer(&self, inner: S) -> Self::Service { + DocumentsLayerService { + inner, + documents: self.documents.clone(), + } + } +} + +/// A Tower service to add documents to your `CompletionRequestBuilder`. +/// Not intended to be used directly - use [`DocumentsLayer`] instead. +/// +/// See [`CompletionRequestBuilderService`] for usage. 
+pub struct DocumentsLayerService { + inner: S, + documents: Vec, +} + +impl Service<(M, Msg)> for DocumentsLayerService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, + S: Service<(M, Msg), Response = CompletionRequestBuilder> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = CompletionRequestBuilder; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: (M, Msg)) -> Self::Future { + let mut inner = self.inner.clone(); + let documents = self.documents.clone(); + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + todo!("Handle error properly"); + }; + + Ok(res.documents(documents)) + }) + } +} + +/// A Tower layer to add a temperature value to your `CompletionRequestBuilder`. +/// Intended to be used with [`CompletionRequestBuilderService`]. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct TemperatureLayer { + temperature: f64, +} + +impl TemperatureLayer { + pub fn new(temperature: f64) -> Self { + Self { temperature } + } +} + +impl Layer for TemperatureLayer { + type Service = TemperatureLayerService; + + fn layer(&self, inner: S) -> Self::Service { + TemperatureLayerService { + inner, + temperature: self.temperature, + } + } +} + +/// A Tower service to add a temperature value to your `CompletionRequestBuilder`. +/// Not intended to be used directly - use [`TemperatureLayer`] instead. +/// +/// See [`CompletionRequestBuilderService`] for usage. 
+pub struct TemperatureLayerService { + inner: S, + temperature: f64, +} + +impl Service<(M, Msg)> for TemperatureLayerService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, + S: Service<(M, Msg), Response = CompletionRequestBuilder> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = CompletionRequestBuilder; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: (M, Msg)) -> Self::Future { + let mut inner = self.inner.clone(); + let temperature = self.temperature; + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + todo!("Handle error properly"); + }; + + Ok(res.temperature(temperature)) + }) + } +} + +/// A Tower service to add tools to your `CompletionRequestBuilder`. +/// Intended to be used with [`CompletionRequestBuilderService`]. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct ToolsLayer { + tools: Vec, +} + +impl ToolsLayer { + pub fn new(tools: Vec) -> Self { + Self { tools } + } +} + +impl Layer for ToolsLayer { + type Service = ToolsLayerService; + + fn layer(&self, inner: S) -> Self::Service { + ToolsLayerService { + inner, + tools: self.tools.clone(), + } + } +} + +/// A Tower service to add tools to your `CompletionRequestBuilder`. +/// Not intended to be used directly - use [`ToolsLayer`] instead. +/// +/// See [`CompletionRequestBuilderService`] for usage. 
+pub struct ToolsLayerService { + inner: S, + tools: Vec, +} + +impl Service<(M, Msg)> for ToolsLayerService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, + S: Service<(M, Msg), Response = CompletionRequestBuilder> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = CompletionRequestBuilder; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: (M, Msg)) -> Self::Future { + let mut inner = self.inner.clone(); + let tools = self.tools.clone(); + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + todo!("Handle error properly"); + }; + + Ok(res.tools(tools)) + }) + } +} + +/// A Tower layer to add a preamble ("system message") to your `CompletionRequestBuilder`. +/// Intended to be used with [`CompletionRequestBuilderService`]. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct PreambleLayer { + preamble: String, +} + +impl PreambleLayer { + pub fn new(preamble: String) -> Self { + Self { preamble } + } +} + +impl Layer for PreambleLayer { + type Service = PreambleLayerService; + + fn layer(&self, inner: S) -> Self::Service { + PreambleLayerService { + inner, + preamble: self.preamble.clone(), + } + } +} + +/// A Tower service to add a preamble ("system message") to your `CompletionRequestBuilder`. +/// Not intended to be used directly - use [`PreambleLayer`] instead. +/// +/// See [`CompletionRequestBuilderService`] for usage. 
+pub struct PreambleLayerService { + inner: S, + preamble: String, +} + +impl Service<(M, Msg)> for PreambleLayerService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, + S: Service<(M, Msg), Response = CompletionRequestBuilder> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = CompletionRequestBuilder; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: (M, Msg)) -> Self::Future { + let mut inner = self.inner.clone(); + let preamble = self.preamble.clone(); + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + todo!("Handle error properly"); + }; + + Ok(res.preamble(preamble)) + }) + } +} + +/// A Tower layer to add additional parameters to your `CompletionRequestBuilder` (which are not already covered by the other parameters). +/// Intended to be used with [`CompletionRequestBuilderService`]. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct AdditionalParamsLayer { + additional_params: serde_json::Value, +} + +impl AdditionalParamsLayer { + pub fn new(additional_params: serde_json::Value) -> Self { + Self { additional_params } + } +} + +impl Layer for AdditionalParamsLayer { + type Service = AdditionalParamsLayerService; + + fn layer(&self, inner: S) -> Self::Service { + AdditionalParamsLayerService { + inner, + additional_params: self.additional_params.clone(), + } + } +} + +/// A Tower layer to add additional parameters to your `CompletionRequestBuilder` (which are not already covered by the other parameters). +/// Not intended to be used directly - use [`AdditionalParamsLayer`] instead. +/// +/// See [`CompletionRequestBuilderService`] for usage. 
+pub struct AdditionalParamsLayerService { + inner: S, + additional_params: serde_json::Value, +} + +impl Service<(M, Msg)> for AdditionalParamsLayerService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, + S: Service<(M, Msg), Response = CompletionRequestBuilder> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = CompletionRequestBuilder; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: (M, Msg)) -> Self::Future { + let mut inner = self.inner.clone(); + let additional_params = self.additional_params.clone(); + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + todo!("Handle error properly"); + }; + + Ok(res.additional_params(additional_params)) + }) + } +} + +/// A Tower layer to add a maximum tokens parameter to your `CompletionRequestBuilder`. +/// Intended to be used with [`CompletionRequestBuilderService`]. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct MaxTokensLayer { + max_tokens: u64, +} + +impl MaxTokensLayer { + pub fn new(max_tokens: u64) -> Self { + Self { max_tokens } + } +} + +impl Layer for MaxTokensLayer { + type Service = MaxTokensLayerService; + + fn layer(&self, inner: S) -> Self::Service { + MaxTokensLayerService { + inner, + max_tokens: self.max_tokens, + } + } +} + +/// A Tower service to add a maximum tokens parameter to your `CompletionRequestBuilder`. +/// Not intended to be used directly - use [`MaxTokensLayer`] instead. +/// +/// See [`CompletionRequestBuilderService`] for usage. 
+pub struct MaxTokensLayerService { + inner: S, + max_tokens: u64, +} + +impl Service<(M, Msg)> for MaxTokensLayerService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, + S: Service<(M, Msg), Response = CompletionRequestBuilder> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = CompletionRequestBuilder; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: (M, Msg)) -> Self::Future { + let mut inner = self.inner.clone(); + let max_tokens = self.max_tokens; + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + todo!("Handle error properly"); + }; + + Ok(res.max_tokens(max_tokens)) + }) + } +} + +/// A Tower layer to add a chat history to your `CompletionRequestBuilder`. +/// Intended to be used with [`CompletionRequestBuilderService`]. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct ChatHistoryLayer { + chat_history: Vec, +} + +impl ChatHistoryLayer { + pub fn new(chat_history: Vec) -> Self { + Self { chat_history } + } +} + +impl Layer for ChatHistoryLayer { + type Service = ChatHistoryLayerService; + + fn layer(&self, inner: S) -> Self::Service { + ChatHistoryLayerService { + inner, + chat_history: self.chat_history.clone(), + } + } +} + +/// A Tower service to add a chat history to your `CompletionRequestBuilder`. +/// Not intended to be used directly - use [`ChatHistoryLayer`] instead. +/// +/// See [`CompletionRequestBuilderService`] for usage. 
+pub struct ChatHistoryLayerService { + inner: S, + chat_history: Vec, +} + +impl Service<(M, Msg)> for ChatHistoryLayerService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, + S: Service<(M, Msg), Response = CompletionRequestBuilder> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = CompletionRequestBuilder; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: (M, Msg)) -> Self::Future { + let mut inner = self.inner.clone(); + let chat_history = self.chat_history.clone(); + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + todo!("Handle error properly"); + }; + + Ok(res.messages(chat_history)) + }) + } +} + +/// A `tower` service for building a completion request. +/// Note that the last layer to be applied goes first and the core service is placed last (so it gets executed from bottom to top). 
+/// +/// Usage: +/// ```rust,ignore +/// let client = rig::providers::openai::Client::from_env(); +/// let model = client.completion_model("gpt-4o"); +/// +/// let mut service = tower::ServiceBuilder::new() +/// .layer(FinishBuilding) +/// .layer(TemperatureLayer::new(0.0)) +/// .layer(PreambleLayer::new("You are a helpful assistant".to_string())) +/// .service(CompletionRequestBuilderService); +/// +/// let request = service.call((model, "Hello world!".to_string())).await.unwrap(); +/// +/// let res = request.send().await.unwrap(); +/// +/// println!("{res:?}"); +/// ``` +pub struct CompletionRequestBuilderService; + +impl Service<(M, Msg)> for CompletionRequestBuilderService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, +{ + type Response = CompletionRequestBuilder; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, (model, prompt): (M, Msg)) -> Self::Future { + Box::pin(async move { Ok(CompletionRequestBuilder::new(model, prompt)) }) + } +} diff --git a/rig-core/src/middlewares/completion.rs b/rig-core/src/middlewares/completion.rs new file mode 100644 index 0000000..5cfe419 --- /dev/null +++ b/rig-core/src/middlewares/completion.rs @@ -0,0 +1,262 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use tower::{Layer, Service}; + +use crate::{ + completion::{ + CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, + CompletionResponse, Document, ToolDefinition, + }, + message::{Message, ToolResultContent, UserContent}, + OneOrMany, +}; + +#[derive(Clone)] +pub struct CompletionLayer { + model: M, + preamble: Option, + documents: Vec, + tools: Vec, + temperature: Option, + max_tokens: Option, + additional_params: Option, +} + +impl CompletionLayer +where + M: CompletionModel, +{ + pub fn builder(model: M) -> CompletionLayerBuilder { + CompletionLayerBuilder::new(model) + } +} + 
+#[derive(Default)] +pub struct CompletionLayerBuilder { + model: M, + preamble: Option, + documents: Vec, + tools: Vec, + temperature: Option, + max_tokens: Option, + additional_params: Option, +} + +impl CompletionLayerBuilder +where + M: CompletionModel, +{ + pub fn new(model: M) -> Self { + Self { + model, + preamble: None, + documents: vec![], + tools: vec![], + temperature: None, + max_tokens: None, + additional_params: None, + } + } + + pub fn preamble(mut self, preamble: String) -> Self { + self.preamble = Some(preamble); + + self + } + + pub fn preamble_opt(mut self, preamble: Option) -> Self { + self.preamble = preamble; + + self + } + + pub fn documents(mut self, documents: Vec) -> Self { + self.documents = documents; + + self + } + + pub fn tools(mut self, tools: Vec) -> Self { + self.tools = tools; + + self + } + + pub fn temperature(mut self, temperature: f64) -> Self { + self.temperature = Some(temperature); + + self + } + + pub fn temperature_opt(mut self, temperature: Option) -> Self { + self.temperature = temperature; + + self + } + + pub fn max_tokens(mut self, max_tokens: u64) -> Self { + self.max_tokens = Some(max_tokens); + + self + } + + pub fn max_tokens_opt(mut self, max_tokens: Option) -> Self { + self.max_tokens = max_tokens; + + self + } + + pub fn additional_params(mut self, params: serde_json::Value) -> Self { + self.additional_params = Some(params); + + self + } + + pub fn additional_params_opt(mut self, params: Option) -> Self { + self.additional_params = params; + + self + } + + pub fn build(self) -> CompletionLayer { + CompletionLayer { + model: self.model, + preamble: self.preamble, + + documents: self.documents, + tools: self.tools, + temperature: self.temperature, + max_tokens: self.max_tokens, + additional_params: self.additional_params, + } + } +} + +impl Layer for CompletionLayer +where + M: CompletionModel, +{ + type Service = CompletionLayerService; + fn layer(&self, inner: S) -> Self::Service { + CompletionLayerService { + 
inner, + model: self.model.clone(), + preamble: self.preamble.clone(), + + documents: self.documents.clone(), + tools: self.tools.clone(), + temperature: self.temperature, + max_tokens: self.max_tokens, + additional_params: self.additional_params.clone(), + } + } +} + +#[derive(Clone)] +pub struct CompletionLayerService { + inner: S, + model: M, + preamble: Option, + documents: Vec, + tools: Vec, + temperature: Option, + max_tokens: Option, + additional_params: Option, +} + +impl Service for CompletionLayerService +where + M: CompletionModel + 'static, + S: Service, String, ToolResultContent)> + + Clone + + Send + + 'static, + S::Future: Send, +{ + type Response = CompletionResponse; + type Error = CompletionError; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: CompletionRequest) -> Self::Future { + let mut inner = self.inner.clone(); + let model = self.model.clone(); + let preamble = self.preamble.clone(); + let documents = self.documents.clone(); + let temperature = self.temperature; + let tools = self.tools.clone(); + let max_tokens = self.max_tokens; + let additional_params = self.additional_params.clone(); + + Box::pin(async move { + let Ok((messages, id, tool_content)) = inner.call(req).await else { + todo!("Handle error properly"); + }; + + let tool_result_message = Message::User { + content: OneOrMany::one(UserContent::tool_result(id, OneOrMany::one(tool_content))), + }; + + let mut req = CompletionRequestBuilder::new(model.clone(), tool_result_message) + .documents(documents.clone()) + .tools(tools.clone()) + .messages(messages) + .temperature_opt(temperature) + .max_tokens_opt(max_tokens) + .additional_params_opt(additional_params.clone()); + + if let Some(preamble) = preamble.clone() { + req = req.preamble(preamble); + } + + let req = req.build(); + + model.completion(req).await + }) + } +} + +/// A completion model as a Tower service. 
+/// +/// This allows you to use an LLM model (or client) essentially anywhere you'd use a regular Tower service, like in an Axum web service. +#[derive(Clone)] +pub struct CompletionService { + /// The model itself. + model: M, +} + +impl CompletionService +where + M: CompletionModel, +{ + pub fn new(model: M) -> Self { + Self { model } + } +} + +impl Service for CompletionService +where + M: CompletionModel + 'static, +{ + type Response = CompletionResponse; + type Error = CompletionError; + + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: CompletionRequest) -> Self::Future { + let model = self.model.clone(); + + Box::pin(async move { model.completion(req).await }) + } +} diff --git a/rig-core/src/middlewares/components.rs b/rig-core/src/middlewares/components.rs new file mode 100644 index 0000000..c8fb3cb --- /dev/null +++ b/rig-core/src/middlewares/components.rs @@ -0,0 +1,247 @@ +use std::{fmt::Debug, future::Future, marker::PhantomData, pin::Pin, task::Poll}; + +use tower::{Layer, Service}; + +use crate::{ + completion::{CompletionRequest, CompletionResponse}, + message::{AssistantContent, Text}, +}; + +#[derive(Clone)] +pub struct AwaitApprovalLayer { + integration: T, +} + +impl AwaitApprovalLayer +where + T: HumanInTheLoop + Clone, +{ + pub fn new(integration: T) -> Self { + Self { integration } + } + + pub fn with_predicate(self, predicate: R) -> AwaitApprovalLayerWithPredicate + where + D: Debug, + R: Fn(&D) -> Pin + Send>> + Clone + Send + 'static, + { + AwaitApprovalLayerWithPredicate { + integration: self.integration, + predicate, + _t: PhantomData, + } + } +} + +impl Layer for AwaitApprovalLayer +where + T: HumanInTheLoop + Clone, +{ + type Service = AwaitApprovalLayerService; + + fn layer(&self, inner: S) -> Self::Service { + AwaitApprovalLayerService::new(inner, self.integration.clone()) + } +} + +#[derive(Clone)] +pub struct AwaitApprovalLayerService { + 
inner: S, + integration: T, +} + +impl AwaitApprovalLayerService +where + T: HumanInTheLoop, +{ + pub fn new(inner: S, integration: T) -> Self { + Self { inner, integration } + } +} + +impl Service for AwaitApprovalLayerService +where + S: Service> + Clone + Send + 'static, + S::Future: Send, + Response: Clone + 'static + Send, + T: HumanInTheLoop + Clone + Send + 'static, +{ + type Response = CompletionResponse; + type Error = bool; + type Future = Pin> + Send>>; + + fn poll_ready( + &mut self, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: CompletionRequest) -> Self::Future { + let mut inner = self.inner.clone(); + let await_approval_loop = self.integration.clone(); + + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + todo!("Handle error properly"); + }; + + let AssistantContent::Text(Text { text }) = res.choice.first() else { + todo!("Handle error properly"); + }; + + if await_approval_loop.send_message(&text).await.is_err() { + todo!("Handle error properly"); + } + + let Ok(bool) = await_approval_loop.await_approval().await else { + todo!("Handle error properly"); + }; + + if bool { + Ok(res) + } else { + todo!("Handle error properly - we should abort the pipeline here if the user wants to abort"); + } + }) + } +} + +pub struct AwaitApprovalLayerWithPredicate { + integration: T, + predicate: R, + _t: PhantomData

, +} + +impl Layer for AwaitApprovalLayerWithPredicate +where + T: HumanInTheLoop + Clone, + D: Debug, + R: Fn(&D) -> Pin + Send>> + Clone + Send + 'static, +{ + type Service = AwaitApprovalLayerServiceWithPredicate; + + fn layer(&self, inner: S) -> Self::Service { + let predicate = self.predicate.clone(); + AwaitApprovalLayerServiceWithPredicate::new(inner, self.integration.clone(), predicate) + } +} + +pub struct AwaitApprovalLayerServiceWithPredicate { + inner: S, + integration: T, + predicate: R, + _t: PhantomData, +} + +impl AwaitApprovalLayerServiceWithPredicate +where + T: HumanInTheLoop, + R: Fn(&D) -> Pin + Send>> + Clone + Send + 'static, + D: Debug, +{ + pub fn new(inner: S, integration: T, predicate: R) -> Self { + Self { + inner, + integration, + predicate, + _t: PhantomData, + } + } +} + +impl Service + for AwaitApprovalLayerServiceWithPredicate +where + R: Fn(&CompletionResponse) -> Pin + Send>> + + Clone + + Send + + 'static, + S: Service> + Clone + Send + 'static, + S::Future: Send, + Response: Clone + 'static + Send, + T: HumanInTheLoop + Clone + Send + 'static, +{ + type Response = CompletionResponse; + type Error = bool; + type Future = Pin> + Send>>; + + fn poll_ready( + &mut self, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: CompletionRequest) -> Self::Future { + let mut inner = self.inner.clone(); + let await_approval_loop = self.integration.clone(); + let predicate = self.predicate.clone(); + + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + todo!("Handle error properly"); + }; + + if predicate(&res).await { + return Ok(res); + } + + let AssistantContent::Text(Text { text }) = res.choice.first() else { + todo!("Handle error properly"); + }; + + if await_approval_loop.send_message(&text).await.is_err() { + todo!("Handle error properly"); + } + + let Ok(bool) = await_approval_loop.await_approval().await else { + todo!("Handle error properly"); + }; + + 
if bool { + Ok(res) + } else { + todo!("Handle error properly - we should abort the pipeline here if the user wants to abort"); + } + }) + } +} + +pub trait HumanInTheLoop { + fn send_message( + &self, + res: &str, + ) -> impl Future>> + Send; + fn await_approval( + &self, + ) -> impl Future>> + Send; +} + +pub struct Stdout; + +impl HumanInTheLoop for Stdout { + async fn send_message(&self, res: &str) -> Result<(), Box> { + print!( + "Current result: {res} + + Would you like to approve this step? [Y/n]" + ); + + Ok(()) + } + + async fn await_approval(&self) -> Result> { + loop { + // Fresh buffer per attempt: `read_line` appends, so a reused buffer would keep earlier invalid input and never match. + let mut string = String::new(); + std::io::stdin().read_line(&mut string).unwrap(); + + match string.to_lowercase().trim() { + "y" | "yes" => break Ok(true), + "n" | "no" => break Ok(false), + _ => println!("Please respond with 'y' or 'n'."), + } + } + } +} diff --git a/rig-core/src/middlewares/extractor.rs b/rig-core/src/middlewares/extractor.rs new file mode 100644 index 0000000..ddd0a29 --- /dev/null +++ b/rig-core/src/middlewares/extractor.rs @@ -0,0 +1,87 @@ +use serde::Deserialize; +use std::{future::Future, marker::PhantomData, pin::Pin, task::Poll}; +use tower::{Layer, Service}; + +use crate::{ + completion::{CompletionRequest, CompletionResponse}, + message::{AssistantContent, Text}, +}; + +use super::ServiceError; + +#[derive(Clone)] +pub struct ExtractorLayer { + _t: PhantomData, +} + +impl ExtractorLayer +where + T: for<'a> Deserialize<'a>, +{ + pub fn new() -> Self { + Self { _t: PhantomData } + } +} + +impl Default for ExtractorLayer +where + T: for<'a> Deserialize<'a>, +{ + fn default() -> Self { + Self::new() + } +} + +impl Layer for ExtractorLayer +where + T: for<'a> Deserialize<'a>, +{ + type Service = ExtractorLayerService; + fn layer(&self, inner: S) -> Self::Service { + ExtractorLayerService { inner, _t: self._t } + } +} + +#[derive(Clone)] +pub struct ExtractorLayerService { + inner: S, + _t: PhantomData, +} + +impl Service for ExtractorLayerService +where + S: 
Service, Error = ServiceError> + + Clone + + Send + + 'static, + S::Future: Send, + F: 'static, + T: for<'a> Deserialize<'a> + 'static, +{ + type Response = T; + type Error = ServiceError; + type Future = Pin> + Send>>; + + fn poll_ready( + &mut self, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: CompletionRequest) -> Self::Future { + let mut inner = self.inner.clone(); + + Box::pin(async move { + let res = inner.call(req).await?; + + let AssistantContent::Text(Text { text }) = res.choice.first() else { + todo!("Handle errors properly"); + }; + + let obj = serde_json::from_str::(&text)?; + + Ok(obj) + }) + } +} diff --git a/rig-core/src/middlewares/mod.rs b/rig-core/src/middlewares/mod.rs new file mode 100644 index 0000000..53da6ab --- /dev/null +++ b/rig-core/src/middlewares/mod.rs @@ -0,0 +1,39 @@ +use thiserror::Error; + +use crate::{ + completion::CompletionError, extractor::ExtractionError, tool::ToolSetError, + vector_store::VectorStoreError, +}; + +pub mod build_completions; +pub mod completion; +pub mod components; +pub mod extractor; +pub mod parallel; +pub mod rag; +pub mod tools; + +#[derive(Debug, Error)] +pub enum ServiceError { + #[error("{0}")] + ExtractionError(#[from] ExtractionError), + #[error("{0}")] + CompletionError(#[from] CompletionError), + #[error("{0}")] + ToolSetError(#[from] ToolSetError), + #[error("{0}")] + VectorStoreError(#[from] VectorStoreError), + #[error("Value required but was null: {0}")] + RequiredOptionNotFound(String), + #[error("{0}")] + Json(#[from] serde_json::Error), + #[error("Custom error: {0}")] + Other(#[from] Box), +} + +impl ServiceError { + pub fn required_option_not_exists>(val: S) -> Self { + let val: String = val.into(); + Self::RequiredOptionNotFound(val) + } +} diff --git a/rig-core/src/middlewares/parallel.rs b/rig-core/src/middlewares/parallel.rs new file mode 100644 index 0000000..1d619be --- /dev/null +++ 
b/rig-core/src/middlewares/parallel.rs
@@ -0,0 +1,88 @@
+use std::{future::Future, pin::Pin};
+
+use tower::Service;
+
+use crate::completion::CompletionRequest;
+
+use super::ServiceError;
+
+/// Pair of results produced by the two services of a [`ParallelService`].
+pub struct Stackable<A, B> {
+    pub inner: A,
+    pub outer: B,
+}
+
+impl<A, B> Stackable<A, B> {
+    /// Bundles `inner` and `outer` into one value.
+    pub fn new(inner: A, outer: B) -> Self {
+        Self { inner, outer }
+    }
+
+    /// Consumes the pair and returns `(inner, outer)`.
+    pub fn take_values(self) -> (A, B) {
+        (self.inner, self.outer)
+    }
+}
+
+/// Service that sends the same request to two services and returns both responses.
+#[derive(Clone)]
+pub struct ParallelService<S, T> {
+    first_service: S,
+    second_service: T,
+}
+
+impl<S, T> ParallelService<S, T>
+where
+    S: Service<CompletionRequest>,
+    T: Service<CompletionRequest>,
+{
+    /// Combines `first_service` and `second_service`.
+    pub fn new(first_service: S, second_service: T) -> Self {
+        Self {
+            first_service,
+            second_service,
+        }
+    }
+}
+
+// NOTE(review): the `Error = ServiceError` bounds below are inferred from the use of
+// `?` into `ServiceError`; confirm the exact bounds.
+impl<S, T> Service<CompletionRequest> for ParallelService<S, T>
+where
+    S: Service<CompletionRequest, Error = ServiceError> + Clone + Send + 'static,
+    S::Future: Send,
+    S::Response: Send + 'static,
+    T: Service<CompletionRequest, Error = ServiceError> + Clone + Send + 'static,
+    T::Future: Send,
+    T::Response: Send + 'static,
+{
+    type Response = Stackable<S::Response, T::Response>;
+    type Error = ServiceError;
+    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
+
+    /// Ready only once both wrapped services are ready; the first readiness error wins.
+    fn poll_ready(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Result<(), Self::Error>> {
+        std::task::ready!(self.first_service.poll_ready(cx))?;
+        self.second_service.poll_ready(cx)
+    }
+
+    /// Sends `req` to the first service and then to the second one.
+    ///
+    /// The two calls are awaited one after the other, not concurrently; the second
+    /// service is not called if the first one fails.
+    fn call(&mut self, req: CompletionRequest) -> Self::Future {
+        // Hand the instances that were polled ready to the future and keep fresh
+        // clones for the next `poll_ready`/`call` pair.
+        let first_replacement = self.first_service.clone();
+        let mut first = std::mem::replace(&mut self.first_service, first_replacement);
+        let second_replacement = self.second_service.clone();
+        let mut second = std::mem::replace(&mut self.second_service, second_replacement);
+
+        Box::pin(async move {
+            let res1 = first.call(req.clone()).await?;
+            // Last use of the request: move it instead of cloning a second time.
+            let res2 = second.call(req).await?;
+
+            Ok(Stackable::new(res1, res2))
+        })
+    }
+}
+
+/// Nests any number (two or more) of services into right-leaning [`ParallelService`]s:
+/// `parallel_service!(a, b, c)` expands to `ParallelService::new(a, ParallelService::new(b, c))`.
+#[macro_export]
+macro_rules! parallel_service {
+    ($service1:tt, $service2:tt) => {
+        $crate::middlewares::parallel::ParallelService::new($service1, $service2)
+    };
+    ($service1:tt $(, $services:tt)+) => {
+        $crate::middlewares::parallel::ParallelService::new(
+            $service1,
+            $crate::parallel_service!($($services),+)
+        )
+    };
+}
diff --git a/rig-core/src/middlewares/rag.rs b/rig-core/src/middlewares/rag.rs
new file mode 100644
index 0000000..5af4f3f
--- /dev/null
+++ b/rig-core/src/middlewares/rag.rs
@@ -0,0 +1,64 @@
+use std::{
+    future::Future,
+    marker::PhantomData,
+    pin::Pin,
+    sync::Arc,
+    task::{Context, Poll},
+};
+
+use serde::{Deserialize, Serialize};
+use tower::Service;
+
+use crate::{completion::CompletionRequest, vector_store::VectorStoreIndex};
+
+use super::ServiceError;
+
+/// Service that looks up the request's RAG text in a vector index and returns the
+/// top `num_results` matches, with documents deserialized as `T`.
+pub struct RagService<V, T> {
+    vector_index: Arc<V>,
+    num_results: usize,
+    _phantom: PhantomData<T>,
+}
+
+impl<V, T> RagService<V, T>
+where
+    V: VectorStoreIndex,
+{
+    /// Wraps `vector_index`; every call asks it for `num_results` results.
+    pub fn new(vector_index: V, num_results: usize) -> Self {
+        Self {
+            vector_index: Arc::new(vector_index),
+            num_results,
+            _phantom: PhantomData,
+        }
+    }
+}
+
+impl<V, T> Service<CompletionRequest> for RagService<V, T>
+where
+    V: VectorStoreIndex + 'static,
+    T: Serialize + for<'a> Deserialize<'a> + Send,
+{
+    type Response = RagResult<T>;
+    type Error = ServiceError;
+    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
+
+    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        Poll::Ready(Ok(()))
+    }
+
+    /// Queries the index with the prompt's RAG text.
+    ///
+    /// # Errors
+    /// Returns [`ServiceError::RequiredOptionNotFound`] when the prompt has no RAG
+    /// text, and propagates any error from the index lookup.
+    fn call(&mut self, req: CompletionRequest) -> Self::Future {
+        let index = Arc::clone(&self.vector_index);
+        let limit = self.num_results;
+
+        Box::pin(async move {
+            let query = req
+                .prompt
+                .rag_text()
+                .ok_or_else(|| ServiceError::required_option_not_exists("rag_text"))?;
+
+            Ok(index.top_n(&query, limit).await?)
+        })
+    }
+}
+
+/// Result rows of a [`RagService`] lookup, as returned by the index's `top_n`.
+pub type RagResult<T> = Vec<(f64, String, T)>;
diff --git a/rig-core/src/middlewares/tools.rs b/rig-core/src/middlewares/tools.rs
new file mode 100644
index 0000000..09c9404
--- /dev/null
+++ b/rig-core/src/middlewares/tools.rs
@@ -0,0 +1,86 @@
+use crate::{
+    completion::{CompletionRequest, CompletionResponse},
+    message::{AssistantContent, Message, ToolResultContent},
+    tool::{ToolSet, ToolSetError},
+    OneOrMany,
+};
+use std::{future::Future, pin::Pin, sync::Arc, task::Poll};
+
+use tower::{Layer, Service};
+
+/// [`Layer`] that executes the tool call returned by a completion service.
+#[derive(Clone)]
+pub struct ToolLayer {
+    tools: Arc<ToolSet>,
+}
+
+impl ToolLayer {
+    /// Creates a layer that resolves tool calls against `tools`.
+    pub fn new(tools: ToolSet) -> Self {
+        Self {
+            tools: Arc::new(tools),
+        }
+    }
+}
+
+impl<S> Layer<S> for ToolLayer {
+    type Service = ToolLayerService<S>;
+
+    fn layer(&self, inner: S) -> Self::Service {
+        ToolLayerService {
+            inner,
+            tools: Arc::clone(&self.tools),
+        }
+    }
+}
+
+/// Service produced by [`ToolLayer`]; shares the layer's tool set.
+#[derive(Clone)]
+pub struct ToolLayerService<S> {
+    inner: S,
+    tools: Arc<ToolSet>,
+}
+
+impl<S, T> Service<CompletionRequest> for ToolLayerService<S>
+where
+    S: Service<CompletionRequest, Response = CompletionResponse<T>> + Clone + Send + 'static,
+    T: Send + 'static,
+    S::Future: Send,
+{
+    /// The request's chat history with the assistant's tool call appended, the tool
+    /// call id, and the tool's output as text content.
+    type Response = (Vec<Message>, String, ToolResultContent);
+    type Error = ToolSetError;
+    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
+
+    // NOTE(review): the inner service's readiness is not consulted. `S::Error` has no
+    // visible conversion into `ToolSetError`, so it cannot be forwarded without
+    // changing the bounds.
+    fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
+        Poll::Ready(Ok(()))
+    }
+
+    /// Runs the inner completion service, then invokes the tool it asked for.
+    ///
+    /// # Errors
+    /// Returns the [`ToolSetError`] reported by the tool set when the tool call fails.
+    ///
+    /// # Panics
+    /// Still panics (`todo!`) when the inner service fails or when its first choice is
+    /// not a tool call; neither case can be expressed as a `ToolSetError` from the
+    /// information available here.
+    fn call(&mut self, req: CompletionRequest) -> Self::Future {
+        let mut inner = self.inner.clone();
+        let tools = Arc::clone(&self.tools);
+        // Copied up front because `req` is moved into the inner call below.
+        let mut messages = req.chat_history.clone();
+
+        Box::pin(async move {
+            let Ok(res) = inner.call(req).await else {
+                todo!("Handle error properly");
+            };
+
+            let AssistantContent::ToolCall(tool_call) = res.choice.first() else {
+                todo!("Handle error properly");
+            };
+
+            messages.push(Message::Assistant {
+                content: OneOrMany::one(AssistantContent::ToolCall(tool_call.clone())),
+            });
+
+            // A failing tool is reported through the service's error type, not a panic.
+            let output = tools
+                .call(
+                    &tool_call.function.name,
+                    tool_call.function.arguments.to_string(),
+                )
+                .await?;
+
+            Ok((messages, tool_call.id, ToolResultContent::text(output)))
+        })
+    }
+}