Make descriptions optional, add static instance of tool, support sync functions.

This commit is contained in:
Collin Brittain 2025-04-07 11:20:50 -05:00
parent 7341e616fa
commit 0a9bbe52b7
2 changed files with 26 additions and 10 deletions

View File

@ -17,6 +17,14 @@ impl Parse for MacroArgs {
let mut description = None;
let mut param_descriptions = HashMap::new();
// If the input is empty, return default values
if input.is_empty() {
return Ok(MacroArgs {
description,
param_descriptions,
});
}
let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input)?;
for meta in meta_list {
@ -28,8 +36,9 @@ impl Parse for MacroArgs {
..
}) = nv.value
{
if ident.as_str() == "description" {
description = Some(lit_str.value());
match ident.as_str() {
"description" => description = Some(lit_str.value()),
_ => {}
}
}
}
@ -112,8 +121,11 @@ pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream {
// Generate PascalCase struct name from the function name
let struct_name = format_ident!("{}", { fn_name_str.to_case(Case::Pascal) });
// Use provided name or function name as default
let tool_description = args.description.unwrap_or_default();
// Use provided description or generate a default one
let tool_description = match args.description {
Some(desc) => quote! { #desc.to_string() },
None => quote! { format!("Function to {}", Self::NAME) },
};
// Extract parameter names, types, and descriptions
let mut param_defs = Vec::new();
@ -125,11 +137,12 @@ pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream {
let param_name = &param_ident.ident;
let param_name_str = param_name.to_string();
let ty = &pat_type.ty;
let default_parameter_description = format!("Parameter {}", param_name_str);
let description = args
.param_descriptions
.get(&param_name_str)
.map(|s| s.as_str())
.unwrap_or("");
.unwrap_or(&default_parameter_description);
param_names.push(param_name);
param_defs.push(quote! {

View File

@ -132,11 +132,14 @@ async fn test_calculator_tool() {
// Test sync calculator
let sync_calculator = SyncCalculator::default();
let result = sync_calculator.call(serde_json::json!({
let result = sync_calculator
.call(serde_json::json!({
"x": 5,
"y": 3,
"operation": "add"
})).await.unwrap();
}))
.await
.unwrap();
assert_eq!(result, serde_json::json!(8));
}