From 78fa6e50615d2ca83786e451d052e250cc16cae0 Mon Sep 17 00:00:00 2001 From: Collin Brittain Date: Thu, 10 Apr 2025 10:45:47 -0500 Subject: [PATCH] Simplify usage of serde_json, remove re-export --- Cargo.lock | 1 + rig-core/rig-core-derive/Cargo.toml | 2 + .../examples/rig_tool/async_tool.rs | 1 - .../rig-core-derive/examples/rig_tool/full.rs | 1 - .../examples/rig_tool/with_description.rs | 1 - rig-core/rig-core-derive/src/lib.rs | 16 +--- rig-core/rig-core-derive/tests/calculator.rs | 74 +++++++++---------- rig-core/src/lib.rs | 1 - 8 files changed, 44 insertions(+), 53 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bc67731..5a7d564 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8705,6 +8705,7 @@ dependencies = [ "quote", "rig-core 0.11.0", "serde", + "serde_json", "syn 2.0.100", "tokio", "tracing-subscriber", diff --git a/rig-core/rig-core-derive/Cargo.toml b/rig-core/rig-core-derive/Cargo.toml index a829320..d5228ae 100644 --- a/rig-core/rig-core-derive/Cargo.toml +++ b/rig-core/rig-core-derive/Cargo.toml @@ -11,6 +11,7 @@ convert_case = { version = "0.6.0" } indoc = "2.0.5" proc-macro2 = { version = "1.0.87", features = ["proc-macro"] } quote = "1.0.37" +serde_json = "1.0.108" syn = { version = "2.0.79", features = ["full"]} [lib] @@ -19,6 +20,7 @@ proc-macro = true [dev-dependencies] rig-core = { path = "../../rig-core" } serde = "1.0" +serde_json = "1.0.108" tokio = { version = "1.44.0", features = ["full"] } tracing-subscriber = "0.3.0" diff --git a/rig-core/rig-core-derive/examples/rig_tool/async_tool.rs b/rig-core/rig-core-derive/examples/rig_tool/async_tool.rs index e00c5c1..bc2fc5d 100644 --- a/rig-core/rig-core-derive/examples/rig_tool/async_tool.rs +++ b/rig-core/rig-core-derive/examples/rig_tool/async_tool.rs @@ -1,6 +1,5 @@ use rig::completion::Prompt; use rig::providers; -use rig::serde_json; use rig::tool::Tool; use rig_derive::rig_tool; use std::time::Duration; diff --git a/rig-core/rig-core-derive/examples/rig_tool/full.rs b/rig-core/rig-core-derive/examples/rig_tool/full.rs index 6321994..d012791 100644 --- a/rig-core/rig-core-derive/examples/rig_tool/full.rs +++ b/rig-core/rig-core-derive/examples/rig_tool/full.rs @@ -1,6 +1,5 @@ use rig::completion::Prompt; use rig::providers; -use rig::serde_json; use rig::tool::Tool; use rig_derive::rig_tool; use tracing_subscriber; diff --git a/rig-core/rig-core-derive/examples/rig_tool/with_description.rs b/rig-core/rig-core-derive/examples/rig_tool/with_description.rs index d05123a..9225bd3 100644 --- a/rig-core/rig-core-derive/examples/rig_tool/with_description.rs +++ b/rig-core/rig-core-derive/examples/rig_tool/with_description.rs @@ -1,6 +1,5 @@ use rig::completion::Prompt; use rig::providers; -use rig::serde_json; use rig::tool::Tool; use rig_derive::rig_tool; use tracing_subscriber; diff --git a/rig-core/rig-core-derive/src/lib.rs b/rig-core/rig-core-derive/src/lib.rs index 5a33faf..7d800bc 100644 --- a/rig-core/rig-core-derive/src/lib.rs +++ b/rig-core/rig-core-derive/src/lib.rs @@ -275,21 +275,13 @@ pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream { let call_impl = if is_async { quote! { async fn call(&self, args: Self::Args) -> Result { - // Extract parameters and call the function - let params: #params_struct_name = rig::serde_json::from_value(args).map_err(|e| rig::tool::ToolError::JsonError(e.into()))?; - let result = #fn_name(#(params.#param_names,)*).await.map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))?; - - Ok(result) + #fn_name(#(args.#param_names,)*).await.map_err(|e| rig::tool::ToolError::ToolCallError(e.into())) } } } else { quote! { async fn call(&self, args: Self::Args) -> Result { - // Extract parameters and call the function - let params: #params_struct_name = rig::serde_json::from_value(args).map_err(|e| rig::tool::ToolError::JsonError(e.into()))?; - let result = #fn_name(#(params.#param_names,)*).map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))?; - - Ok(result) + #fn_name(#(args.#param_names,)*).map_err(|e| rig::tool::ToolError::ToolCallError(e.into())) } } }; @@ -308,7 +300,7 @@ pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream { impl rig::tool::Tool for #struct_name { const NAME: &'static str = #fn_name_str; - type Args = rig::serde_json::Value; + type Args = #params_struct_name; type Output = #output_type; type Error = rig::tool::ToolError; @@ -317,7 +309,7 @@ pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream { } async fn definition(&self, _prompt: String) -> rig::completion::ToolDefinition { - let parameters = rig::serde_json::json!({ + let parameters = serde_json::json!({ "type": "object", "properties": { #( diff --git a/rig-core/rig-core-derive/tests/calculator.rs b/rig-core/rig-core-derive/tests/calculator.rs index 244eafb..97e6b8c 100644 --- a/rig-core/rig-core-derive/tests/calculator.rs +++ b/rig-core/rig-core-derive/tests/calculator.rs @@ -74,72 +74,72 @@ async fn test_calculator_tool() { // Test valid operations let test_cases = vec![ ( - rig::serde_json::json!({ - "x": 5, - "y": 3, - "operation": "add" - }), + CalculatorParameters { + x: 5, + y: 3, + operation: "add".to_string(), + }, 8, ), ( - rig::serde_json::json!({ - "x": 5, - "y": 3, - "operation": "subtract" - }), + CalculatorParameters { + x: 5, + y: 3, + operation: "subtract".to_string(), + }, 2, ), ( - rig::serde_json::json!({ - "x": 5, - "y": 3, - "operation": "multiply" - }), + CalculatorParameters { + x: 5, + y: 3, + operation: "multiply".to_string(), + }, 15, ), ( - rig::serde_json::json!({ - "x": 6, - "y": 2, - "operation": "divide" - }), + CalculatorParameters { + x: 6, + y: 2, + operation: "divide".to_string(), + }, 3, ), ]; for (input, expected) in test_cases { let result = calculator.call(input).await.unwrap(); - assert_eq!(result, rig::serde_json::json!(expected)); + assert_eq!(result, serde_json::json!(expected)); } // Test division by zero - let div_zero = rig::serde_json::json!({ - "x": 5, - "y": 0, - "operation": "divide" - }); + let div_zero = CalculatorParameters { + x: 5, + y: 0, + operation: "divide".to_string(), + }; let err = calculator.call(div_zero).await.unwrap_err(); assert!(matches!(err, rig::tool::ToolError::ToolCallError(_))); // Test invalid operation - let invalid_op = rig::serde_json::json!({ - "x": 5, - "y": 3, - "operation": "power" - }); + let invalid_op = CalculatorParameters { + x: 5, + y: 3, + operation: "power".to_string(), + }; let err = calculator.call(invalid_op).await.unwrap_err(); assert!(matches!(err, rig::tool::ToolError::ToolCallError(_))); // Test sync calculator let sync_calculator = SyncCalculator::default(); let result = sync_calculator - .call(rig::serde_json::json!({ - "x": 5, - "y": 3, - "operation": "add" - })) + .call(SyncCalculatorParameters { + x: 5, + y: 3, + operation: "add".to_string(), + }) .await .unwrap(); - assert_eq!(result, rig::serde_json::json!(8)); + assert_eq!(result, serde_json::json!(8)); } diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 627667f..300c962 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -98,7 +98,6 @@ pub mod streaming; pub mod tool; pub mod transcription; pub mod vector_store; -pub use serde_json; // Re-export commonly used types and traits pub use completion::message;