Simplify usage of serde_json, remove re-export

This commit is contained in:
Collin Brittain 2025-04-10 10:45:47 -05:00
parent 9cb7a38cb4
commit 78fa6e5061
8 changed files with 44 additions and 53 deletions

1
Cargo.lock generated
View File

@ -8705,6 +8705,7 @@ dependencies = [
"quote", "quote",
"rig-core 0.11.0", "rig-core 0.11.0",
"serde", "serde",
"serde_json",
"syn 2.0.100", "syn 2.0.100",
"tokio", "tokio",
"tracing-subscriber", "tracing-subscriber",

View File

@ -11,6 +11,7 @@ convert_case = { version = "0.6.0" }
indoc = "2.0.5" indoc = "2.0.5"
proc-macro2 = { version = "1.0.87", features = ["proc-macro"] } proc-macro2 = { version = "1.0.87", features = ["proc-macro"] }
quote = "1.0.37" quote = "1.0.37"
serde_json = "1.0.108"
syn = { version = "2.0.79", features = ["full"]} syn = { version = "2.0.79", features = ["full"]}
[lib] [lib]
@ -19,6 +20,7 @@ proc-macro = true
[dev-dependencies] [dev-dependencies]
rig-core = { path = "../../rig-core" } rig-core = { path = "../../rig-core" }
serde = "1.0" serde = "1.0"
serde_json = "1.0.108"
tokio = { version = "1.44.0", features = ["full"] } tokio = { version = "1.44.0", features = ["full"] }
tracing-subscriber = "0.3.0" tracing-subscriber = "0.3.0"

View File

@ -1,6 +1,5 @@
use rig::completion::Prompt; use rig::completion::Prompt;
use rig::providers; use rig::providers;
use rig::serde_json;
use rig::tool::Tool; use rig::tool::Tool;
use rig_derive::rig_tool; use rig_derive::rig_tool;
use std::time::Duration; use std::time::Duration;

View File

@ -1,6 +1,5 @@
use rig::completion::Prompt; use rig::completion::Prompt;
use rig::providers; use rig::providers;
use rig::serde_json;
use rig::tool::Tool; use rig::tool::Tool;
use rig_derive::rig_tool; use rig_derive::rig_tool;
use tracing_subscriber; use tracing_subscriber;

View File

@ -1,6 +1,5 @@
use rig::completion::Prompt; use rig::completion::Prompt;
use rig::providers; use rig::providers;
use rig::serde_json;
use rig::tool::Tool; use rig::tool::Tool;
use rig_derive::rig_tool; use rig_derive::rig_tool;
use tracing_subscriber; use tracing_subscriber;

View File

@ -275,21 +275,13 @@ pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream {
let call_impl = if is_async { let call_impl = if is_async {
quote! { quote! {
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> { async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
// Extract parameters and call the function #fn_name(#(args.#param_names,)*).await.map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))
let params: #params_struct_name = rig::serde_json::from_value(args).map_err(|e| rig::tool::ToolError::JsonError(e.into()))?;
let result = #fn_name(#(params.#param_names,)*).await.map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))?;
Ok(result)
} }
} }
} else { } else {
quote! { quote! {
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> { async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
// Extract parameters and call the function #fn_name(#(args.#param_names,)*).map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))
let params: #params_struct_name = rig::serde_json::from_value(args).map_err(|e| rig::tool::ToolError::JsonError(e.into()))?;
let result = #fn_name(#(params.#param_names,)*).map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))?;
Ok(result)
} }
} }
}; };
@ -308,7 +300,7 @@ pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream {
impl rig::tool::Tool for #struct_name { impl rig::tool::Tool for #struct_name {
const NAME: &'static str = #fn_name_str; const NAME: &'static str = #fn_name_str;
type Args = rig::serde_json::Value; type Args = #params_struct_name;
type Output = #output_type; type Output = #output_type;
type Error = rig::tool::ToolError; type Error = rig::tool::ToolError;
@ -317,7 +309,7 @@ pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream {
} }
async fn definition(&self, _prompt: String) -> rig::completion::ToolDefinition { async fn definition(&self, _prompt: String) -> rig::completion::ToolDefinition {
let parameters = rig::serde_json::json!({ let parameters = serde_json::json!({
"type": "object", "type": "object",
"properties": { "properties": {
#( #(

View File

@ -74,72 +74,72 @@ async fn test_calculator_tool() {
// Test valid operations // Test valid operations
let test_cases = vec![ let test_cases = vec![
( (
rig::serde_json::json!({ CalculatorParameters {
"x": 5, x: 5,
"y": 3, y: 3,
"operation": "add" operation: "add".to_string(),
}), },
8, 8,
), ),
( (
rig::serde_json::json!({ CalculatorParameters {
"x": 5, x: 5,
"y": 3, y: 3,
"operation": "subtract" operation: "subtract".to_string(),
}), },
2, 2,
), ),
( (
rig::serde_json::json!({ CalculatorParameters {
"x": 5, x: 5,
"y": 3, y: 3,
"operation": "multiply" operation: "multiply".to_string(),
}), },
15, 15,
), ),
( (
rig::serde_json::json!({ CalculatorParameters {
"x": 6, x: 6,
"y": 2, y: 2,
"operation": "divide" operation: "divide".to_string(),
}), },
3, 3,
), ),
]; ];
for (input, expected) in test_cases { for (input, expected) in test_cases {
let result = calculator.call(input).await.unwrap(); let result = calculator.call(input).await.unwrap();
assert_eq!(result, rig::serde_json::json!(expected)); assert_eq!(result, serde_json::json!(expected));
} }
// Test division by zero // Test division by zero
let div_zero = rig::serde_json::json!({ let div_zero = CalculatorParameters {
"x": 5, x: 5,
"y": 0, y: 0,
"operation": "divide" operation: "divide".to_string(),
}); };
let err = calculator.call(div_zero).await.unwrap_err(); let err = calculator.call(div_zero).await.unwrap_err();
assert!(matches!(err, rig::tool::ToolError::ToolCallError(_))); assert!(matches!(err, rig::tool::ToolError::ToolCallError(_)));
// Test invalid operation // Test invalid operation
let invalid_op = rig::serde_json::json!({ let invalid_op = CalculatorParameters {
"x": 5, x: 5,
"y": 3, y: 3,
"operation": "power" operation: "power".to_string(),
}); };
let err = calculator.call(invalid_op).await.unwrap_err(); let err = calculator.call(invalid_op).await.unwrap_err();
assert!(matches!(err, rig::tool::ToolError::ToolCallError(_))); assert!(matches!(err, rig::tool::ToolError::ToolCallError(_)));
// Test sync calculator // Test sync calculator
let sync_calculator = SyncCalculator::default(); let sync_calculator = SyncCalculator::default();
let result = sync_calculator let result = sync_calculator
.call(rig::serde_json::json!({ .call(SyncCalculatorParameters {
"x": 5, x: 5,
"y": 3, y: 3,
"operation": "add" operation: "add".to_string(),
})) })
.await .await
.unwrap(); .unwrap();
assert_eq!(result, rig::serde_json::json!(8)); assert_eq!(result, serde_json::json!(8));
} }

View File

@ -98,7 +98,6 @@ pub mod streaming;
pub mod tool; pub mod tool;
pub mod transcription; pub mod transcription;
pub mod vector_store; pub mod vector_store;
pub use serde_json;
// Re-export commonly used types and traits // Re-export commonly used types and traits
pub use completion::message; pub use completion::message;