diff --git a/rig-core/examples/tower.rs b/rig-core/examples/tower.rs new file mode 100644 index 0000000..0949b11 --- /dev/null +++ b/rig-core/examples/tower.rs @@ -0,0 +1,67 @@ +use rig::{ + completion::CompletionRequestBuilder, + middlewares::{ + completion::{CompletionLayer, CompletionService}, + tools::ToolLayer, + }, + providers::openai::Client, +}; +use tower::{Service, ServiceBuilder}; + +#[tokio::main] +async fn main() { + let client = Client::from_env(); + let model = client.completion_model("gpt-4o"); + + let comp_layer = CompletionLayer::builder(model.clone()).build(); + let tool_layer = ToolLayer::new(vec![Add]); + let service = CompletionService::new(model.clone()); + + let mut service = ServiceBuilder::new() + .layer(comp_layer) + .layer(tool_layer) + .service(service); + + let comp_request = CompletionRequestBuilder::new(model, "Please calculate 5+5 for me").build(); + + let res = service.call(comp_request).await.unwrap(); + + println!("{res:?}"); +} + +#[derive(Deserialize, Serialize)] +struct Add; + +impl Tool for Add { + const NAME: &'static str = "add"; + + type Error = MathError; + type Args = OperationArgs; + type Output = i32; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + serde_json::from_value(json!({ + "name": "add", + "description": "Add x and y together", + "parameters": { + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "The first number to add" + }, + "y": { + "type": "number", + "description": "The second number to add" + } + } + } + })) + .expect("Tool Definition") + } + + async fn call(&self, args: Self::Args) -> Result { + let result = args.x + args.y; + Ok(result) + } +} diff --git a/rig-core/src/completion/request.rs b/rig-core/src/completion/request.rs index 9a31fae..60456b0 100644 --- a/rig-core/src/completion/request.rs +++ b/rig-core/src/completion/request.rs @@ -235,6 +235,7 @@ pub trait CompletionModel: Clone + Send + Sync { } /// Struct representing a general completion request that can be sent to 
a completion model provider. +#[derive(Clone)] pub struct CompletionRequest { /// The prompt to be sent to the completion model provider pub prompt: Message, diff --git a/rig-core/src/middlewares/build_completions.rs b/rig-core/src/middlewares/build_completions.rs new file mode 100644 index 0000000..f2d626d --- /dev/null +++ b/rig-core/src/middlewares/build_completions.rs @@ -0,0 +1,532 @@ +use std::{future::Future, pin::Pin, task::Poll}; +use tower::{Layer, Service}; + +use crate::{ + completion::{ + CompletionModel, CompletionRequest, CompletionRequestBuilder, Document, ToolDefinition, + }, + message::Message, +}; + +/// A Tower layer to finish building your `CompletionRequestBuilder`. +/// Intended to be used with [`CompletionRequestBuilderService`]. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct FinishBuilding; + +impl Layer for FinishBuilding { + type Service = FinishBuildingService; + + fn layer(&self, inner: S) -> Self::Service { + FinishBuildingService { inner } + } +} + +/// A Tower service to finish building your `CompletionRequestBuilder`. +/// Not intended to be used directly. Use [`FinishBuilding`] instead. +/// +/// See [`CompletionRequestBuilderService`] for usage. 
+pub struct FinishBuildingService { + inner: S, +} + +impl Service<(M, Msg)> for FinishBuildingService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, + S: Service<(M, Msg), Response = CompletionRequestBuilder> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = CompletionRequest; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: (M, Msg)) -> Self::Future { + let mut inner = self.inner.clone(); + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + return Err(()); + }; + + Ok(res.build()) + }) + } +} + +/// A Tower layer to add documents to your `CompletionRequestBuilder`. +/// Intended to be used with [`CompletionRequestBuilderService`]. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct DocumentsLayer { + documents: Vec, +} + +impl DocumentsLayer { + pub fn new(documents: Vec) -> Self { + Self { documents } + } +} + +impl Layer for DocumentsLayer { + type Service = DocumentsLayerService; + + fn layer(&self, inner: S) -> Self::Service { + DocumentsLayerService { + inner, + documents: self.documents.clone(), + } + } +} + +/// A Tower service to add documents to your `CompletionRequestBuilder`. +/// Not intended to be used directly - use [`DocumentsLayer`] instead. +/// +/// See [`CompletionRequestBuilderService`] for usage. 
+pub struct DocumentsLayerService { + inner: S, + documents: Vec, +} + +impl Service<(M, Msg)> for DocumentsLayerService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, + S: Service<(M, Msg), Response = CompletionRequestBuilder> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = CompletionRequestBuilder; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: (M, Msg)) -> Self::Future { + let mut inner = self.inner.clone(); + let documents = self.documents.clone(); + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + return Err(()); + }; + + Ok(res.documents(documents)) + }) + } +} + +/// A Tower layer to add a temperature value to your `CompletionRequestBuilder`. +/// Intended to be used with [`CompletionRequestBuilderService`]. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct TemperatureLayer { + temperature: f64, +} + +impl TemperatureLayer { + pub fn new(temperature: f64) -> Self { + Self { temperature } + } +} + +impl Layer for TemperatureLayer { + type Service = TemperatureLayerService; + + fn layer(&self, inner: S) -> Self::Service { + TemperatureLayerService { + inner, + temperature: self.temperature, + } + } +} + +/// A Tower service to add a temperature value to your `CompletionRequestBuilder`. +/// Not intended to be used directly - use [`TemperatureLayer`] instead. +/// +/// See [`CompletionRequestBuilderService`] for usage. 
+pub struct TemperatureLayerService { + inner: S, + temperature: f64, +} + +impl Service<(M, Msg)> for TemperatureLayerService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, + S: Service<(M, Msg), Response = CompletionRequestBuilder> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = CompletionRequestBuilder; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: (M, Msg)) -> Self::Future { + let mut inner = self.inner.clone(); + let temperature = self.temperature; + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + return Err(()); + }; + + Ok(res.temperature(temperature)) + }) + } +} + +/// A Tower layer to add tools to your `CompletionRequestBuilder`. +/// Intended to be used with [`CompletionRequestBuilderService`]. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct ToolsLayer { + tools: Vec, +} + +impl ToolsLayer { + pub fn new(tools: Vec) -> Self { + Self { tools } + } +} + +impl Layer for ToolsLayer { + type Service = ToolsLayerService; + + fn layer(&self, inner: S) -> Self::Service { + ToolsLayerService { + inner, + tools: self.tools.clone(), + } + } +} + +/// A Tower service to add tools to your `CompletionRequestBuilder`. +/// Not intended to be used directly - use [`ToolsLayer`] instead. +/// +/// See [`CompletionRequestBuilderService`] for usage. 
+pub struct ToolsLayerService { + inner: S, + tools: Vec, +} + +impl Service<(M, Msg)> for ToolsLayerService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, + S: Service<(M, Msg), Response = CompletionRequestBuilder> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = CompletionRequestBuilder; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: (M, Msg)) -> Self::Future { + let mut inner = self.inner.clone(); + let tools = self.tools.clone(); + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + return Err(()); + }; + + Ok(res.tools(tools)) + }) + } +} + +/// A Tower layer to add a preamble ("system message") to your `CompletionRequestBuilder`. +/// Intended to be used with [`CompletionRequestBuilderService`]. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct PreambleLayer { + preamble: String, +} + +impl PreambleLayer { + pub fn new(preamble: String) -> Self { + Self { preamble } + } +} + +impl Layer for PreambleLayer { + type Service = PreambleLayerService; + + fn layer(&self, inner: S) -> Self::Service { + PreambleLayerService { + inner, + preamble: self.preamble.clone(), + } + } +} + +/// A Tower service to add a preamble ("system message") to your `CompletionRequestBuilder`. +/// Not intended to be used directly - use [`PreambleLayer`] instead. +/// +/// See [`CompletionRequestBuilderService`] for usage. 
+pub struct PreambleLayerService { + inner: S, + preamble: String, +} + +impl Service<(M, Msg)> for PreambleLayerService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, + S: Service<(M, Msg), Response = CompletionRequestBuilder> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = CompletionRequestBuilder; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: (M, Msg)) -> Self::Future { + let mut inner = self.inner.clone(); + let preamble = self.preamble.clone(); + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + return Err(()); + }; + + Ok(res.preamble(preamble)) + }) + } +} + +/// A Tower layer to add additional parameters to your `CompletionRequestBuilder` (which are not already covered by the other parameters). +/// Intended to be used with [`CompletionRequestBuilderService`]. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct AdditionalParamsLayer { + additional_params: serde_json::Value, +} + +impl AdditionalParamsLayer { + pub fn new(additional_params: serde_json::Value) -> Self { + Self { additional_params } + } +} + +impl Layer for AdditionalParamsLayer { + type Service = AdditionalParamsLayerService; + + fn layer(&self, inner: S) -> Self::Service { + AdditionalParamsLayerService { + inner, + additional_params: self.additional_params.clone(), + } + } +} + +/// A Tower service to add additional parameters to your `CompletionRequestBuilder` (which are not already covered by the other parameters). +/// Not intended to be used directly - use [`AdditionalParamsLayer`] instead. +/// +/// See [`CompletionRequestBuilderService`] for usage. 
+pub struct AdditionalParamsLayerService { + inner: S, + additional_params: serde_json::Value, +} + +impl Service<(M, Msg)> for AdditionalParamsLayerService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, + S: Service<(M, Msg), Response = CompletionRequestBuilder> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = CompletionRequestBuilder; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: (M, Msg)) -> Self::Future { + let mut inner = self.inner.clone(); + let additional_params = self.additional_params.clone(); + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + return Err(()); + }; + + Ok(res.additional_params(additional_params)) + }) + } +} + +/// A Tower layer to add a maximum tokens parameter to your `CompletionRequestBuilder`. +/// Intended to be used with [`CompletionRequestBuilderService`]. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct MaxTokensLayer { + max_tokens: u64, +} + +impl MaxTokensLayer { + pub fn new(max_tokens: u64) -> Self { + Self { max_tokens } + } +} + +impl Layer for MaxTokensLayer { + type Service = MaxTokensLayerService; + + fn layer(&self, inner: S) -> Self::Service { + MaxTokensLayerService { + inner, + max_tokens: self.max_tokens, + } + } +} + +/// A Tower service to add a maximum tokens parameter to your `CompletionRequestBuilder`. +/// Not intended to be used directly - use [`MaxTokensLayer`] instead. +/// +/// See [`CompletionRequestBuilderService`] for usage. 
+pub struct MaxTokensLayerService { + inner: S, + max_tokens: u64, +} + +impl Service<(M, Msg)> for MaxTokensLayerService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, + S: Service<(M, Msg), Response = CompletionRequestBuilder> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = CompletionRequestBuilder; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: (M, Msg)) -> Self::Future { + let mut inner = self.inner.clone(); + let max_tokens = self.max_tokens; + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + return Err(()); + }; + + Ok(res.max_tokens(max_tokens)) + }) + } +} + +/// A Tower layer to add a chat history to your `CompletionRequestBuilder`. +/// Intended to be used with [`CompletionRequestBuilderService`]. +/// +/// See [`CompletionRequestBuilderService`] for usage. +pub struct ChatHistoryLayer { + chat_history: Vec, +} + +impl ChatHistoryLayer { + pub fn new(chat_history: Vec) -> Self { + Self { chat_history } + } +} + +impl Layer for ChatHistoryLayer { + type Service = ChatHistoryLayerService; + + fn layer(&self, inner: S) -> Self::Service { + ChatHistoryLayerService { + inner, + chat_history: self.chat_history.clone(), + } + } +} + +/// A Tower service to add a chat history to your `CompletionRequestBuilder`. +/// Not intended to be used directly - use [`ChatHistoryLayer`] instead. +/// +/// See [`CompletionRequestBuilderService`] for usage. 
+pub struct ChatHistoryLayerService { + inner: S, + chat_history: Vec, +} + +impl Service<(M, Msg)> for ChatHistoryLayerService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, + S: Service<(M, Msg), Response = CompletionRequestBuilder> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = CompletionRequestBuilder; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: (M, Msg)) -> Self::Future { + let mut inner = self.inner.clone(); + let chat_history = self.chat_history.clone(); + Box::pin(async move { + let Ok(res) = inner.call(req).await else { + return Err(()); + }; + + Ok(res.messages(chat_history)) + }) + } +} + +/// A `tower` service for building a completion request. +/// Note that the last layer to be applied goes first and the core service is placed last (so it gets executed from bottom to top). 
+/// +/// Usage: +/// ```rust,ignore +/// let agent = rig::providers::openai::Client::from_env(); +/// let model = agent.completion_model("gpt-4o"); +/// +/// let mut service = tower::ServiceBuilder::new() +/// .layer(FinishBuilding) +/// .layer(TemperatureLayer::new(0.0)) +/// .layer(PreambleLayer::new("You are a helpful assistant".to_string())) +/// .service(CompletionRequestBuilderService); +/// +/// let request = service.call((model, "Hello world!".to_string())).await.unwrap(); +/// +/// let res = request.send().await.unwrap(); +/// +/// println!("{res:?}"); +/// ``` +pub struct CompletionRequestBuilderService; + +impl Service<(M, Msg)> for CompletionRequestBuilderService +where + M: CompletionModel + Send + 'static, + Msg: Into + Send + 'static, +{ + type Response = CompletionRequestBuilder; + type Error = (); + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, (model, prompt): (M, Msg)) -> Self::Future { + Box::pin(async move { Ok(CompletionRequestBuilder::new(model, prompt)) }) + } +} diff --git a/rig-core/src/middlewares/completion.rs b/rig-core/src/middlewares/completion.rs index df6714a..1b216f2 100644 --- a/rig-core/src/middlewares/completion.rs +++ b/rig-core/src/middlewares/completion.rs @@ -224,6 +224,7 @@ where /// A completion model as a Tower service. /// /// This allows you to use an LLM model (or client) essentially anywhere you'd use a regular Tower layer, like in an Axum web service. +#[derive(Clone)] pub struct CompletionService { /// The model itself. 
model: M, diff --git a/rig-core/src/middlewares/extractor.rs b/rig-core/src/middlewares/extractor.rs index 938e08d..528828c 100644 --- a/rig-core/src/middlewares/extractor.rs +++ b/rig-core/src/middlewares/extractor.rs @@ -1,68 +1,54 @@ -use std::{future::Future, pin::Pin, sync::Arc, task::Poll}; - -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; +use std::{future::Future, marker::PhantomData, pin::Pin, task::Poll}; use tower::{Layer, Service}; use crate::{ - completion::{CompletionModel, CompletionRequest, CompletionResponse}, - extractor::{ExtractionError, Extractor}, + completion::{CompletionRequest, CompletionResponse}, message::{AssistantContent, Text}, }; -pub struct ExtractorLayer -where - M: CompletionModel, - T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, -{ - ext: Arc>, +use super::ServiceError; + +pub struct ExtractorLayer { + _t: PhantomData, } -impl ExtractorLayer +impl ExtractorLayer where - M: CompletionModel, - T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, + T: for<'a> Deserialize<'a>, { - pub fn new(ext: Extractor) -> Self { - Self { ext: Arc::new(ext) } + pub fn new() -> Self { + Self { _t: PhantomData } } } -impl Layer for ExtractorLayer +impl Layer for ExtractorLayer where - M: CompletionModel + 'static, - T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static, + T: for<'a> Deserialize<'a>, { - type Service = ExtractorLayerService; + type Service = ExtractorLayerService; fn layer(&self, inner: S) -> Self::Service { - ExtractorLayerService { - inner, - ext: Arc::clone(&self.ext), - } + ExtractorLayerService { inner, _t: self._t } } } -pub struct ExtractorLayerService -where - M: CompletionModel + 'static, - T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static, -{ +pub struct ExtractorLayerService { inner: S, - ext: Arc>, + _t: PhantomData, } -impl Service for ExtractorLayerService +impl Service for ExtractorLayerService where - S: Service> + S: Service, Error = 
ServiceError> + Clone + Send + 'static, S::Future: Send, - M: CompletionModel + 'static, - T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync, + F: 'static, + T: for<'a> Deserialize<'a> + 'static, { type Response = T; - type Error = ExtractionError; + type Error = ServiceError; type Future = Pin> + Send>>; fn poll_ready( @@ -73,63 +59,18 @@ where } fn call(&mut self, req: CompletionRequest) -> Self::Future { - let ext = self.ext.clone(); let mut inner = self.inner.clone(); Box::pin(async move { - let Ok(res) = inner.call(req).await else { - todo!("Properly handle error"); - }; + let res = inner.call(req).await?; let AssistantContent::Text(Text { text }) = res.choice.first() else { todo!("Handle errors properly"); }; - ext.extract(&text).await + let obj = serde_json::from_str::(&text)?; + + Ok(obj) }) } } - -pub struct ExtractorService -where - M: CompletionModel + 'static, - T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static, -{ - ext: Arc>, -} - -impl ExtractorService -where - M: CompletionModel + 'static, - T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static, -{ - pub fn new(ext: Extractor) -> Self { - Self { ext: Arc::new(ext) } - } -} - -impl Service for ExtractorService -where - M: CompletionModel, - T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, -{ - type Response = T; - type Error = ExtractionError; - type Future = Pin> + Send>>; - - fn poll_ready( - &mut self, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: CompletionRequest) -> Self::Future { - let ext = self.ext.clone(); - let Some(req) = req.prompt.rag_text() else { - todo!("Handle error properly"); - }; - - Box::pin(async move { ext.extract(&req).await }) - } -} diff --git a/rig-core/src/middlewares/mod.rs b/rig-core/src/middlewares/mod.rs index f99abea..2e02d6f 100644 --- a/rig-core/src/middlewares/mod.rs +++ b/rig-core/src/middlewares/mod.rs @@ -1,10 +1,15 @@ use thiserror::Error; -use 
crate::{completion::CompletionError, extractor::ExtractionError, tool::ToolSetError}; +use crate::{ + completion::CompletionError, extractor::ExtractionError, tool::ToolSetError, + vector_store::VectorStoreError, +}; +pub mod build_completions; pub mod completion; pub mod components; pub mod extractor; +pub mod parallel; pub mod rag; pub mod tools; @@ -16,4 +21,17 @@ pub enum ServiceError { CompletionError(#[from] CompletionError), #[error("{0}")] ToolSetError(#[from] ToolSetError), + #[error("{0}")] + VectorStoreError(#[from] VectorStoreError), + #[error("Value required but was null: {0}")] + RequiredOptionNotFound(String), + #[error("{0}")] + Json(#[from] serde_json::Error), +} + +impl ServiceError { + pub fn required_option_not_exists>(val: S) -> Self { + let val: String = val.into(); + Self::RequiredOptionNotFound(val) + } } diff --git a/rig-core/src/middlewares/parallel.rs b/rig-core/src/middlewares/parallel.rs new file mode 100644 index 0000000..1d619be --- /dev/null +++ b/rig-core/src/middlewares/parallel.rs @@ -0,0 +1,88 @@ +use std::{future::Future, pin::Pin}; + +use tower::Service; + +use crate::completion::CompletionRequest; + +use super::ServiceError; + +pub struct Stackable { + pub inner: A, + pub outer: B, +} + +impl Stackable { + pub fn new(inner: A, outer: B) -> Self { + Self { inner, outer } + } + + pub fn take_values(self) -> (A, B) { + (self.inner, self.outer) + } +} + +#[derive(Clone)] +pub struct ParallelService { + first_service: S, + second_service: T, +} + +impl ParallelService +where + S: Service, + T: Service, +{ + pub fn new(first_service: S, second_service: T) -> Self { + Self { + first_service, + second_service, + } + } +} + +impl Service for ParallelService +where + S: Service + Clone + Send + 'static, + S::Future: Send, + S::Response: Send + 'static, + T: Service + Clone + Send + 'static, + T::Future: Send, + T::Response: Send + 'static, +{ + type Response = Stackable; + type Error = ServiceError; + type Future = Pin> + Send>>; + + fn 
poll_ready( + &mut self, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + + fn call(&mut self, req: CompletionRequest) -> Self::Future { + let mut first = self.first_service.clone(); + let mut second = self.second_service.clone(); + Box::pin(async move { + let res1 = first.call(req.clone()).await?; + let res2 = second.call(req).await?; + + let stackable = Stackable::new(res1, res2); + + Ok(stackable) + }) + } +} + +#[macro_export] +macro_rules! parallel_service { + ($service1:tt, $service2:tt) => { + $crate::middlewares::parallel::ParallelService::new($service1, $service2) + }; + ($service1:tt $(, $ops:tt)*) => { + $crate::middlewares::parallel::ParallelService::new( + $service1, + $crate::parallel_service!($($ops),*) + ) + }; +} diff --git a/rig-core/src/middlewares/rag.rs b/rig-core/src/middlewares/rag.rs index ebf3dcd..5af4f3f 100644 --- a/rig-core/src/middlewares/rag.rs +++ b/rig-core/src/middlewares/rag.rs @@ -9,10 +9,9 @@ use std::{ use serde::{Deserialize, Serialize}; use tower::Service; -use crate::{ - completion::CompletionRequest, - vector_store::{VectorStoreError, VectorStoreIndex}, -}; +use crate::{completion::CompletionRequest, vector_store::VectorStoreIndex}; + +use super::ServiceError; pub struct RagService { vector_index: Arc, @@ -39,7 +38,7 @@ where T: Serialize + for<'a> Deserialize<'a> + Send, { type Response = RagResult; - type Error = VectorStoreError; + type Error = ServiceError; type Future = Pin> + Send>>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { @@ -49,11 +48,16 @@ where fn call(&mut self, req: CompletionRequest) -> Self::Future { let vector_index = self.vector_index.clone(); let num_results = self.num_results; - let Some(prompt) = req.prompt.rag_text() else { - todo!("Handle error properly"); - }; - Box::pin(async move { vector_index.top_n(&prompt, num_results).await }) + Box::pin(async move { + let Some(prompt) = req.prompt.rag_text() else { + return 
Err(ServiceError::required_option_not_exists("rag_text")); + }; + + let res = vector_index.top_n(&prompt, num_results).await?; + + Ok(res) + }) } } diff --git a/rig-core/src/middlewares/tools.rs b/rig-core/src/middlewares/tools.rs index 5a5fc82..e0ab19a 100644 --- a/rig-core/src/middlewares/tools.rs +++ b/rig-core/src/middlewares/tools.rs @@ -1,7 +1,6 @@ use crate::{ completion::{CompletionRequest, CompletionResponse}, - message::{AssistantContent, Message, ToolResultContent, UserContent}, - providers::{self}, + message::{AssistantContent, Message, ToolResultContent}, tool::{ToolSet, ToolSetError}, OneOrMany, }; @@ -83,49 +82,3 @@ where }) } } - -pub struct ToolService { - tools: Arc, -} - -type OpenAIResponse = CompletionResponse; - -impl Service for ToolService { - type Response = Message; - type Error = ToolSetError; - type Future = Pin> + Send>>; - - fn poll_ready( - &mut self, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: OpenAIResponse) -> Self::Future { - let tools = self.tools.clone(); - - Box::pin(async move { - let crate::message::AssistantContent::ToolCall(tool_call) = req.choice.first() else { - unimplemented!("handle error"); - }; - - let Ok(res) = tools - .call( - &tool_call.function.name, - tool_call.function.arguments.to_string(), - ) - .await - else { - todo!("Implement proper error handling"); - }; - - Ok(Message::User { - content: OneOrMany::one(UserContent::tool_result( - tool_call.id, - OneOrMany::one(ToolResultContent::text(res)), - )), - }) - }) - } -}