Simplify usage of serde_json, remove re-export

This commit is contained in:
Collin Brittain 2025-04-10 10:45:47 -05:00
parent 9cb7a38cb4
commit 78fa6e5061
8 changed files with 44 additions and 53 deletions

1
Cargo.lock generated
View File

@ -8705,6 +8705,7 @@ dependencies = [
"quote",
"rig-core 0.11.0",
"serde",
"serde_json",
"syn 2.0.100",
"tokio",
"tracing-subscriber",

View File

@ -11,6 +11,7 @@ convert_case = { version = "0.6.0" }
indoc = "2.0.5"
proc-macro2 = { version = "1.0.87", features = ["proc-macro"] }
quote = "1.0.37"
serde_json = "1.0.108"
syn = { version = "2.0.79", features = ["full"]}
[lib]
@ -19,6 +20,7 @@ proc-macro = true
[dev-dependencies]
rig-core = { path = "../../rig-core" }
serde = "1.0"
serde_json = "1.0.108"
tokio = { version = "1.44.0", features = ["full"] }
tracing-subscriber = "0.3.0"

View File

@ -1,6 +1,5 @@
use rig::completion::Prompt;
use rig::providers;
use rig::serde_json;
use rig::tool::Tool;
use rig_derive::rig_tool;
use std::time::Duration;

View File

@ -1,6 +1,5 @@
use rig::completion::Prompt;
use rig::providers;
use rig::serde_json;
use rig::tool::Tool;
use rig_derive::rig_tool;
use tracing_subscriber;

View File

@ -1,6 +1,5 @@
use rig::completion::Prompt;
use rig::providers;
use rig::serde_json;
use rig::tool::Tool;
use rig_derive::rig_tool;
use tracing_subscriber;

View File

@ -275,21 +275,13 @@ pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream {
let call_impl = if is_async {
quote! {
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
// Extract parameters and call the function
let params: #params_struct_name = rig::serde_json::from_value(args).map_err(|e| rig::tool::ToolError::JsonError(e.into()))?;
let result = #fn_name(#(params.#param_names,)*).await.map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))?;
Ok(result)
#fn_name(#(args.#param_names,)*).await.map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))
}
}
} else {
quote! {
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
// Extract parameters and call the function
let params: #params_struct_name = rig::serde_json::from_value(args).map_err(|e| rig::tool::ToolError::JsonError(e.into()))?;
let result = #fn_name(#(params.#param_names,)*).map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))?;
Ok(result)
#fn_name(#(args.#param_names,)*).map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))
}
}
};
@ -308,7 +300,7 @@ pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream {
impl rig::tool::Tool for #struct_name {
const NAME: &'static str = #fn_name_str;
type Args = rig::serde_json::Value;
type Args = #params_struct_name;
type Output = #output_type;
type Error = rig::tool::ToolError;
@ -317,7 +309,7 @@ pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream {
}
async fn definition(&self, _prompt: String) -> rig::completion::ToolDefinition {
let parameters = rig::serde_json::json!({
let parameters = serde_json::json!({
"type": "object",
"properties": {
#(

View File

@ -74,72 +74,72 @@ async fn test_calculator_tool() {
// Test valid operations
let test_cases = vec![
(
rig::serde_json::json!({
"x": 5,
"y": 3,
"operation": "add"
}),
CalculatorParameters {
x: 5,
y: 3,
operation: "add".to_string(),
},
8,
),
(
rig::serde_json::json!({
"x": 5,
"y": 3,
"operation": "subtract"
}),
CalculatorParameters {
x: 5,
y: 3,
operation: "subtract".to_string(),
},
2,
),
(
rig::serde_json::json!({
"x": 5,
"y": 3,
"operation": "multiply"
}),
CalculatorParameters {
x: 5,
y: 3,
operation: "multiply".to_string(),
},
15,
),
(
rig::serde_json::json!({
"x": 6,
"y": 2,
"operation": "divide"
}),
CalculatorParameters {
x: 6,
y: 2,
operation: "divide".to_string(),
},
3,
),
];
for (input, expected) in test_cases {
let result = calculator.call(input).await.unwrap();
assert_eq!(result, rig::serde_json::json!(expected));
assert_eq!(result, serde_json::json!(expected));
}
// Test division by zero
let div_zero = rig::serde_json::json!({
"x": 5,
"y": 0,
"operation": "divide"
});
let div_zero = CalculatorParameters {
x: 5,
y: 0,
operation: "divide".to_string(),
};
let err = calculator.call(div_zero).await.unwrap_err();
assert!(matches!(err, rig::tool::ToolError::ToolCallError(_)));
// Test invalid operation
let invalid_op = rig::serde_json::json!({
"x": 5,
"y": 3,
"operation": "power"
});
let invalid_op = CalculatorParameters {
x: 5,
y: 3,
operation: "power".to_string(),
};
let err = calculator.call(invalid_op).await.unwrap_err();
assert!(matches!(err, rig::tool::ToolError::ToolCallError(_)));
// Test sync calculator
let sync_calculator = SyncCalculator::default();
let result = sync_calculator
.call(rig::serde_json::json!({
"x": 5,
"y": 3,
"operation": "add"
}))
.call(SyncCalculatorParameters {
x: 5,
y: 3,
operation: "add".to_string(),
})
.await
.unwrap();
assert_eq!(result, rig::serde_json::json!(8));
assert_eq!(result, serde_json::json!(8));
}

View File

@ -98,7 +98,6 @@ pub mod streaming;
pub mod tool;
pub mod transcription;
pub mod vector_store;
pub use serde_json;
// Re-export commonly used types and traits
pub use completion::message;