diff --git a/rig-core/examples/tower.rs b/rig-core/examples/tower.rs index 0949b11..f1643d1 100644 --- a/rig-core/examples/tower.rs +++ b/rig-core/examples/tower.rs @@ -1,11 +1,14 @@ use rig::{ - completion::CompletionRequestBuilder, + completion::{CompletionRequestBuilder, ToolDefinition}, middlewares::{ completion::{CompletionLayer, CompletionService}, tools::ToolLayer, }, providers::openai::Client, + tool::{Tool, ToolSet}, }; +use serde::{Deserialize, Serialize}; +use tower::Service; use tower::ServiceBuilder; #[tokio::main] @@ -13,11 +16,11 @@ async fn main() { let client = Client::from_env(); let model = client.completion_model("gpt-4o"); - let comp_layer = CompletionLayer::builder(model).build(); - let tool_layer = ToolLayer::new(vec![Add]); - let service = CompletionService::new(model); + let comp_layer = CompletionLayer::builder(model.clone()).build(); + let tool_layer = ToolLayer::new(ToolSet::from_tools(vec![Add])); + let service = CompletionService::new(model.clone()); - let service = ServiceBuilder::new() + let mut service = ServiceBuilder::new() .layer(comp_layer) .layer(tool_layer) .service(service); @@ -32,6 +35,16 @@ async fn main() { #[derive(Deserialize, Serialize)] struct Add; +#[derive(Debug, thiserror::Error)] +#[error("Math error")] +struct MathError; + +#[derive(Deserialize)] +struct OperationArgs { + x: i32, + y: i32, +} + impl Tool for Add { const NAME: &'static str = "add"; @@ -40,7 +53,7 @@ impl Tool for Add { type Output = i32; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ + serde_json::from_value(serde_json::json!({ "name": "add", "description": "Add x and y together", "parameters": { diff --git a/rig-core/src/middlewares/completion.rs b/rig-core/src/middlewares/completion.rs index 1b216f2..5cfe419 100644 --- a/rig-core/src/middlewares/completion.rs +++ b/rig-core/src/middlewares/completion.rs @@ -15,6 +15,7 @@ use crate::{ OneOrMany, }; +#[derive(Clone)] pub struct CompletionLayer { model: M, preamble: Option, @@ -155,6 +156,7 @@ where } } +#[derive(Clone)] pub struct CompletionLayerService { inner: S, model: M, diff --git a/rig-core/src/middlewares/components.rs b/rig-core/src/middlewares/components.rs index 5b1d5c4..e51bf20 100644 --- a/rig-core/src/middlewares/components.rs +++ b/rig-core/src/middlewares/components.rs @@ -7,6 +7,7 @@ use crate::{ message::{AssistantContent, Text}, }; +#[derive(Clone)] pub struct AwaitApprovalLayer { integration: T, } @@ -43,6 +44,7 @@ where } } +#[derive(Clone)] pub struct AwaitApprovalLayerService { inner: S, integration: T, diff --git a/rig-core/src/middlewares/extractor.rs b/rig-core/src/middlewares/extractor.rs index 6d3de9d..ddd0a29 100644 --- a/rig-core/src/middlewares/extractor.rs +++ b/rig-core/src/middlewares/extractor.rs @@ -9,6 +9,7 @@ use crate::{ use super::ServiceError; +#[derive(Clone)] pub struct ExtractorLayer { _t: PhantomData, } @@ -41,6 +42,7 @@ where } } +#[derive(Clone)] pub struct ExtractorLayerService { inner: S, _t: PhantomData, diff --git a/rig-core/src/middlewares/tools.rs b/rig-core/src/middlewares/tools.rs index e0ab19a..09c9404 100644 --- a/rig-core/src/middlewares/tools.rs +++ b/rig-core/src/middlewares/tools.rs @@ -8,6 +8,7 @@ use std::{future::Future, pin::Pin, sync::Arc, task::Poll}; use tower::{Layer, Service}; +#[derive(Clone)] pub struct ToolLayer { tools: Arc, } @@ -31,6 +32,7 @@ impl Layer for ToolLayer { } } +#[derive(Clone)] pub struct ToolLayerService { inner: S, tools: Arc,