mirror of https://github.com/0xplaygrounds/rig
Simplify usage of serde_json, remove re-export
This commit is contained in:
parent
9cb7a38cb4
commit
78fa6e5061
|
@ -8705,6 +8705,7 @@ dependencies = [
|
||||||
"quote",
|
"quote",
|
||||||
"rig-core 0.11.0",
|
"rig-core 0.11.0",
|
||||||
"serde",
|
"serde",
|
||||||
|
"serde_json",
|
||||||
"syn 2.0.100",
|
"syn 2.0.100",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
|
|
|
@ -11,6 +11,7 @@ convert_case = { version = "0.6.0" }
|
||||||
indoc = "2.0.5"
|
indoc = "2.0.5"
|
||||||
proc-macro2 = { version = "1.0.87", features = ["proc-macro"] }
|
proc-macro2 = { version = "1.0.87", features = ["proc-macro"] }
|
||||||
quote = "1.0.37"
|
quote = "1.0.37"
|
||||||
|
serde_json = "1.0.108"
|
||||||
syn = { version = "2.0.79", features = ["full"]}
|
syn = { version = "2.0.79", features = ["full"]}
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
|
@ -19,6 +20,7 @@ proc-macro = true
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
rig-core = { path = "../../rig-core" }
|
rig-core = { path = "../../rig-core" }
|
||||||
serde = "1.0"
|
serde = "1.0"
|
||||||
|
serde_json = "1.0.108"
|
||||||
tokio = { version = "1.44.0", features = ["full"] }
|
tokio = { version = "1.44.0", features = ["full"] }
|
||||||
tracing-subscriber = "0.3.0"
|
tracing-subscriber = "0.3.0"
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
use rig::completion::Prompt;
|
use rig::completion::Prompt;
|
||||||
use rig::providers;
|
use rig::providers;
|
||||||
use rig::serde_json;
|
|
||||||
use rig::tool::Tool;
|
use rig::tool::Tool;
|
||||||
use rig_derive::rig_tool;
|
use rig_derive::rig_tool;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
use rig::completion::Prompt;
|
use rig::completion::Prompt;
|
||||||
use rig::providers;
|
use rig::providers;
|
||||||
use rig::serde_json;
|
|
||||||
use rig::tool::Tool;
|
use rig::tool::Tool;
|
||||||
use rig_derive::rig_tool;
|
use rig_derive::rig_tool;
|
||||||
use tracing_subscriber;
|
use tracing_subscriber;
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
use rig::completion::Prompt;
|
use rig::completion::Prompt;
|
||||||
use rig::providers;
|
use rig::providers;
|
||||||
use rig::serde_json;
|
|
||||||
use rig::tool::Tool;
|
use rig::tool::Tool;
|
||||||
use rig_derive::rig_tool;
|
use rig_derive::rig_tool;
|
||||||
use tracing_subscriber;
|
use tracing_subscriber;
|
||||||
|
|
|
@ -275,21 +275,13 @@ pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream {
|
||||||
let call_impl = if is_async {
|
let call_impl = if is_async {
|
||||||
quote! {
|
quote! {
|
||||||
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||||
// Extract parameters and call the function
|
#fn_name(#(args.#param_names,)*).await.map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))
|
||||||
let params: #params_struct_name = rig::serde_json::from_value(args).map_err(|e| rig::tool::ToolError::JsonError(e.into()))?;
|
|
||||||
let result = #fn_name(#(params.#param_names,)*).await.map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))?;
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
quote! {
|
quote! {
|
||||||
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||||
// Extract parameters and call the function
|
#fn_name(#(args.#param_names,)*).map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))
|
||||||
let params: #params_struct_name = rig::serde_json::from_value(args).map_err(|e| rig::tool::ToolError::JsonError(e.into()))?;
|
|
||||||
let result = #fn_name(#(params.#param_names,)*).map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))?;
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -308,7 +300,7 @@ pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream {
|
||||||
impl rig::tool::Tool for #struct_name {
|
impl rig::tool::Tool for #struct_name {
|
||||||
const NAME: &'static str = #fn_name_str;
|
const NAME: &'static str = #fn_name_str;
|
||||||
|
|
||||||
type Args = rig::serde_json::Value;
|
type Args = #params_struct_name;
|
||||||
type Output = #output_type;
|
type Output = #output_type;
|
||||||
type Error = rig::tool::ToolError;
|
type Error = rig::tool::ToolError;
|
||||||
|
|
||||||
|
@ -317,7 +309,7 @@ pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn definition(&self, _prompt: String) -> rig::completion::ToolDefinition {
|
async fn definition(&self, _prompt: String) -> rig::completion::ToolDefinition {
|
||||||
let parameters = rig::serde_json::json!({
|
let parameters = serde_json::json!({
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
#(
|
#(
|
||||||
|
|
|
@ -74,72 +74,72 @@ async fn test_calculator_tool() {
|
||||||
// Test valid operations
|
// Test valid operations
|
||||||
let test_cases = vec![
|
let test_cases = vec![
|
||||||
(
|
(
|
||||||
rig::serde_json::json!({
|
CalculatorParameters {
|
||||||
"x": 5,
|
x: 5,
|
||||||
"y": 3,
|
y: 3,
|
||||||
"operation": "add"
|
operation: "add".to_string(),
|
||||||
}),
|
},
|
||||||
8,
|
8,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
rig::serde_json::json!({
|
CalculatorParameters {
|
||||||
"x": 5,
|
x: 5,
|
||||||
"y": 3,
|
y: 3,
|
||||||
"operation": "subtract"
|
operation: "subtract".to_string(),
|
||||||
}),
|
},
|
||||||
2,
|
2,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
rig::serde_json::json!({
|
CalculatorParameters {
|
||||||
"x": 5,
|
x: 5,
|
||||||
"y": 3,
|
y: 3,
|
||||||
"operation": "multiply"
|
operation: "multiply".to_string(),
|
||||||
}),
|
},
|
||||||
15,
|
15,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
rig::serde_json::json!({
|
CalculatorParameters {
|
||||||
"x": 6,
|
x: 6,
|
||||||
"y": 2,
|
y: 2,
|
||||||
"operation": "divide"
|
operation: "divide".to_string(),
|
||||||
}),
|
},
|
||||||
3,
|
3,
|
||||||
),
|
),
|
||||||
];
|
];
|
||||||
|
|
||||||
for (input, expected) in test_cases {
|
for (input, expected) in test_cases {
|
||||||
let result = calculator.call(input).await.unwrap();
|
let result = calculator.call(input).await.unwrap();
|
||||||
assert_eq!(result, rig::serde_json::json!(expected));
|
assert_eq!(result, serde_json::json!(expected));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test division by zero
|
// Test division by zero
|
||||||
let div_zero = rig::serde_json::json!({
|
let div_zero = CalculatorParameters {
|
||||||
"x": 5,
|
x: 5,
|
||||||
"y": 0,
|
y: 0,
|
||||||
"operation": "divide"
|
operation: "divide".to_string(),
|
||||||
});
|
};
|
||||||
let err = calculator.call(div_zero).await.unwrap_err();
|
let err = calculator.call(div_zero).await.unwrap_err();
|
||||||
assert!(matches!(err, rig::tool::ToolError::ToolCallError(_)));
|
assert!(matches!(err, rig::tool::ToolError::ToolCallError(_)));
|
||||||
|
|
||||||
// Test invalid operation
|
// Test invalid operation
|
||||||
let invalid_op = rig::serde_json::json!({
|
let invalid_op = CalculatorParameters {
|
||||||
"x": 5,
|
x: 5,
|
||||||
"y": 3,
|
y: 3,
|
||||||
"operation": "power"
|
operation: "power".to_string(),
|
||||||
});
|
};
|
||||||
let err = calculator.call(invalid_op).await.unwrap_err();
|
let err = calculator.call(invalid_op).await.unwrap_err();
|
||||||
assert!(matches!(err, rig::tool::ToolError::ToolCallError(_)));
|
assert!(matches!(err, rig::tool::ToolError::ToolCallError(_)));
|
||||||
|
|
||||||
// Test sync calculator
|
// Test sync calculator
|
||||||
let sync_calculator = SyncCalculator::default();
|
let sync_calculator = SyncCalculator::default();
|
||||||
let result = sync_calculator
|
let result = sync_calculator
|
||||||
.call(rig::serde_json::json!({
|
.call(SyncCalculatorParameters {
|
||||||
"x": 5,
|
x: 5,
|
||||||
"y": 3,
|
y: 3,
|
||||||
"operation": "add"
|
operation: "add".to_string(),
|
||||||
}))
|
})
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(result, rig::serde_json::json!(8));
|
assert_eq!(result, serde_json::json!(8));
|
||||||
}
|
}
|
||||||
|
|
|
@ -98,7 +98,6 @@ pub mod streaming;
|
||||||
pub mod tool;
|
pub mod tool;
|
||||||
pub mod transcription;
|
pub mod transcription;
|
||||||
pub mod vector_store;
|
pub mod vector_store;
|
||||||
pub use serde_json;
|
|
||||||
|
|
||||||
// Re-export commonly used types and traits
|
// Re-export commonly used types and traits
|
||||||
pub use completion::message;
|
pub use completion::message;
|
||||||
|
|
Loading…
Reference in New Issue