mirror of https://github.com/0xplaygrounds/rig
feat: Add `rig_tool` macro (#353)
* feat: rig-macros crates providing a rig_tool proc macro * Support non-async functions as tools * Add static instance of generated tool * Make descriptions optional, add static instance of tool, support sync functions. * Add examples * Fixing Cargo.toml merge issues * Fix examples, replace use of schemars * Remove excessive comments and fix broken test * Move rig-macros contents into rig-core-derive * Doc comment + examples * Make doc linter happy * Simplify usage of serde_json, remove re-export
This commit is contained in:
parent
d10d1cc73b
commit
c7d4851e32
|
@ -8699,10 +8699,16 @@ dependencies = [
|
||||||
name = "rig-derive"
|
name = "rig-derive"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"convert_case 0.6.0",
|
||||||
"indoc",
|
"indoc",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
|
"rig-core 0.11.0",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
"syn 2.0.100",
|
"syn 2.0.100",
|
||||||
|
"tokio",
|
||||||
|
"tracing-subscriber",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
|
@ -7,10 +7,35 @@ description = "Internal crate that implements Rig derive macros."
|
||||||
repository = "https://github.com/0xPlaygrounds/rig"
|
repository = "https://github.com/0xPlaygrounds/rig"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
convert_case = { version = "0.6.0" }
|
||||||
indoc = "2.0.5"
|
indoc = "2.0.5"
|
||||||
proc-macro2 = { version = "1.0.87", features = ["proc-macro"] }
|
proc-macro2 = { version = "1.0.87", features = ["proc-macro"] }
|
||||||
quote = "1.0.37"
|
quote = "1.0.37"
|
||||||
|
serde_json = "1.0.108"
|
||||||
syn = { version = "2.0.79", features = ["full"]}
|
syn = { version = "2.0.79", features = ["full"]}
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
proc-macro = true
|
proc-macro = true
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
rig-core = { path = "../../rig-core" }
|
||||||
|
serde = "1.0"
|
||||||
|
serde_json = "1.0.108"
|
||||||
|
tokio = { version = "1.44.0", features = ["full"] }
|
||||||
|
tracing-subscriber = "0.3.0"
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "simple"
|
||||||
|
path = "examples/rig_tool/simple.rs"
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "with_description"
|
||||||
|
path = "examples/rig_tool/with_description.rs"
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "full"
|
||||||
|
path = "examples/rig_tool/full.rs"
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "async_tool"
|
||||||
|
path = "examples/rig_tool/async_tool.rs"
|
||||||
|
|
|
@ -0,0 +1,53 @@
|
||||||
|
use rig::completion::Prompt;
|
||||||
|
use rig::providers;
|
||||||
|
use rig::tool::Tool;
|
||||||
|
use rig_derive::rig_tool;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tracing_subscriber;
|
||||||
|
|
||||||
|
// Example demonstrating async tool usage
|
||||||
|
#[rig_tool(
|
||||||
|
description = "A tool that simulates an async operation",
|
||||||
|
params(
|
||||||
|
input = "Input value to process",
|
||||||
|
delay_ms = "Delay in milliseconds before returning result"
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
async fn async_operation(input: String, delay_ms: u64) -> Result<String, rig::tool::ToolError> {
|
||||||
|
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
|
||||||
|
|
||||||
|
Ok(format!(
|
||||||
|
"Processed after {}ms: {}",
|
||||||
|
delay_ms,
|
||||||
|
input.to_uppercase()
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
tracing_subscriber::fmt().pretty().init();
|
||||||
|
|
||||||
|
let async_agent = providers::openai::Client::from_env()
|
||||||
|
.agent(providers::openai::GPT_4O)
|
||||||
|
.preamble("You are an agent with tools access, always use the tools")
|
||||||
|
.max_tokens(1024)
|
||||||
|
.tool(AsyncOperation)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
println!("Tool definition:");
|
||||||
|
println!(
|
||||||
|
"ASYNCOPERATION: {}",
|
||||||
|
serde_json::to_string_pretty(&AsyncOperation.definition(String::default()).await).unwrap()
|
||||||
|
);
|
||||||
|
|
||||||
|
for prompt in [
|
||||||
|
"What tools do you have?",
|
||||||
|
"Process the text 'hello world' with a delay of 1000ms",
|
||||||
|
"Process the text 'async operation' with a delay of 500ms",
|
||||||
|
"Process the text 'concurrent calls' with a delay of 200ms",
|
||||||
|
"Process the text 'error handling' with a delay of 'not a number'",
|
||||||
|
] {
|
||||||
|
println!("User: {}", prompt);
|
||||||
|
println!("Agent: {}", async_agent.prompt(prompt).await.unwrap());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,58 @@
|
||||||
|
use rig::completion::Prompt;
|
||||||
|
use rig::providers;
|
||||||
|
use rig::tool::Tool;
|
||||||
|
use rig_derive::rig_tool;
|
||||||
|
use tracing_subscriber;
|
||||||
|
|
||||||
|
// Example with full attributes including parameter descriptions
|
||||||
|
#[rig_tool(
|
||||||
|
description = "A tool that performs string operations",
|
||||||
|
params(
|
||||||
|
text = "The input text to process",
|
||||||
|
operation = "The operation to perform (uppercase, lowercase, reverse)",
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
fn string_processor(text: String, operation: String) -> Result<String, rig::tool::ToolError> {
|
||||||
|
let result = match operation.as_str() {
|
||||||
|
"uppercase" => text.to_uppercase(),
|
||||||
|
"lowercase" => text.to_lowercase(),
|
||||||
|
"reverse" => text.chars().rev().collect(),
|
||||||
|
_ => {
|
||||||
|
return Err(rig::tool::ToolError::ToolCallError(
|
||||||
|
format!("Unknown operation: {}", operation).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
tracing_subscriber::fmt().pretty().init();
|
||||||
|
|
||||||
|
let string_agent = providers::openai::Client::from_env()
|
||||||
|
.agent(providers::openai::GPT_4O)
|
||||||
|
.preamble("You are an agent with tools access, always use the tools")
|
||||||
|
.max_tokens(1024)
|
||||||
|
.tool(StringProcessor)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
println!("Tool definition:");
|
||||||
|
println!(
|
||||||
|
"STRINGPROCESSOR: {}",
|
||||||
|
serde_json::to_string_pretty(&StringProcessor.definition(String::default()).await).unwrap()
|
||||||
|
);
|
||||||
|
|
||||||
|
for prompt in [
|
||||||
|
"What tools do you have?",
|
||||||
|
"Convert 'hello world' to uppercase",
|
||||||
|
"Convert 'HELLO WORLD' to lowercase",
|
||||||
|
"Reverse the string 'hello world'",
|
||||||
|
"Convert 'hello world' to uppercase and repeat it 3 times",
|
||||||
|
"Perform an invalid operation on 'hello world'",
|
||||||
|
] {
|
||||||
|
println!("User: {}", prompt);
|
||||||
|
println!("Agent: {}", string_agent.prompt(prompt).await.unwrap());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,71 @@
|
||||||
|
use rig::completion::Prompt;
|
||||||
|
use rig::providers;
|
||||||
|
use rig_derive::rig_tool;
|
||||||
|
use tracing_subscriber;
|
||||||
|
|
||||||
|
// Simple example with no attributes
|
||||||
|
#[rig_tool]
|
||||||
|
fn add(a: i32, b: i32) -> Result<i32, rig::tool::ToolError> {
|
||||||
|
Ok(a + b)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rig_tool]
|
||||||
|
fn subtract(a: i32, b: i32) -> Result<i32, rig::tool::ToolError> {
|
||||||
|
Ok(a - b)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rig_tool]
|
||||||
|
fn multiply(a: i32, b: i32) -> Result<i32, rig::tool::ToolError> {
|
||||||
|
Ok(a * b)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rig_tool]
|
||||||
|
fn divide(a: i32, b: i32) -> Result<i32, rig::tool::ToolError> {
|
||||||
|
if b == 0 {
|
||||||
|
Err(rig::tool::ToolError::ToolCallError(
|
||||||
|
"Division by zero".into(),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
Ok(a / b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rig_tool]
|
||||||
|
fn answer_secret_question() -> Result<(bool, bool, bool, bool, bool), rig::tool::ToolError> {
|
||||||
|
Ok((false, false, true, false, false))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rig_tool]
|
||||||
|
fn how_many_rs(s: String) -> Result<usize, rig::tool::ToolError> {
|
||||||
|
Ok(s.chars()
|
||||||
|
.filter(|c| *c == 'r' || *c == 'R')
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rig_tool]
|
||||||
|
fn sum_numbers(numbers: Vec<i64>) -> Result<i64, rig::tool::ToolError> {
|
||||||
|
Ok(numbers.iter().sum())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
tracing_subscriber::fmt().pretty().init();
|
||||||
|
|
||||||
|
let calculator_agent = providers::openai::Client::from_env()
|
||||||
|
.agent(providers::openai::GPT_4O)
|
||||||
|
.preamble("You are an agent with tools access, always use the tools")
|
||||||
|
.max_tokens(1024)
|
||||||
|
.tool(Add)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
for prompt in [
|
||||||
|
"What tools do you have?",
|
||||||
|
"Calculate 5 + 3",
|
||||||
|
"What is 10 + 20?",
|
||||||
|
"Add 100 and 200",
|
||||||
|
] {
|
||||||
|
println!("User: {}", prompt);
|
||||||
|
println!("Agent: {}", calculator_agent.prompt(prompt).await.unwrap());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,57 @@
|
||||||
|
use rig::completion::Prompt;
|
||||||
|
use rig::providers;
|
||||||
|
use rig::tool::Tool;
|
||||||
|
use rig_derive::rig_tool;
|
||||||
|
use tracing_subscriber;
|
||||||
|
|
||||||
|
// Example with description attribute
|
||||||
|
#[rig_tool(description = "Perform basic arithmetic operations")]
|
||||||
|
fn calculator(x: i32, y: i32, operation: String) -> Result<i32, rig::tool::ToolError> {
|
||||||
|
match operation.as_str() {
|
||||||
|
"add" => Ok(x + y),
|
||||||
|
"subtract" => Ok(x - y),
|
||||||
|
"multiply" => Ok(x * y),
|
||||||
|
"divide" => {
|
||||||
|
if y == 0 {
|
||||||
|
Err(rig::tool::ToolError::ToolCallError(
|
||||||
|
"Division by zero".into(),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
Ok(x / y)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Err(rig::tool::ToolError::ToolCallError(
|
||||||
|
format!("Unknown operation: {}", operation).into(),
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
tracing_subscriber::fmt().pretty().init();
|
||||||
|
|
||||||
|
let calculator_agent = providers::openai::Client::from_env()
|
||||||
|
.agent(providers::openai::GPT_4O)
|
||||||
|
.preamble("You are an agent with tools access, always use the tools")
|
||||||
|
.max_tokens(1024)
|
||||||
|
.tool(Calculator)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
println!("Tool definition:");
|
||||||
|
println!(
|
||||||
|
"CALCULATOR: {}",
|
||||||
|
serde_json::to_string_pretty(&CALCULATOR.definition(String::default()).await).unwrap()
|
||||||
|
);
|
||||||
|
|
||||||
|
for prompt in [
|
||||||
|
"What tools do you have?",
|
||||||
|
"Calculate 5 + 3",
|
||||||
|
"What is 10 - 4?",
|
||||||
|
"Multiply 6 and 7",
|
||||||
|
"Divide 20 by 5",
|
||||||
|
"What is 10 / 0?",
|
||||||
|
] {
|
||||||
|
println!("User: {}", prompt);
|
||||||
|
println!("Agent: {}", calculator_agent.prompt(prompt).await.unwrap());
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,6 +1,15 @@
|
||||||
extern crate proc_macro;
|
extern crate proc_macro;
|
||||||
|
|
||||||
|
use convert_case::{Case, Casing};
|
||||||
use proc_macro::TokenStream;
|
use proc_macro::TokenStream;
|
||||||
use syn::{parse_macro_input, DeriveInput};
|
use quote::{format_ident, quote};
|
||||||
|
use std::{collections::HashMap, ops::Deref};
|
||||||
|
use syn::{
|
||||||
|
parse::{Parse, ParseStream},
|
||||||
|
parse_macro_input,
|
||||||
|
punctuated::Punctuated,
|
||||||
|
DeriveInput, Expr, ExprLit, Lit, Meta, PathArguments, ReturnType, Token, Type,
|
||||||
|
};
|
||||||
|
|
||||||
mod basic;
|
mod basic;
|
||||||
mod custom;
|
mod custom;
|
||||||
|
@ -19,3 +28,311 @@ pub fn derive_embedding_trait(item: TokenStream) -> TokenStream {
|
||||||
.unwrap_or_else(syn::Error::into_compile_error)
|
.unwrap_or_else(syn::Error::into_compile_error)
|
||||||
.into()
|
.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct MacroArgs {
|
||||||
|
description: Option<String>,
|
||||||
|
param_descriptions: HashMap<String, String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Parse for MacroArgs {
|
||||||
|
fn parse(input: ParseStream) -> syn::Result<Self> {
|
||||||
|
let mut description = None;
|
||||||
|
let mut param_descriptions = HashMap::new();
|
||||||
|
|
||||||
|
// If the input is empty, return default values
|
||||||
|
if input.is_empty() {
|
||||||
|
return Ok(MacroArgs {
|
||||||
|
description,
|
||||||
|
param_descriptions,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input)?;
|
||||||
|
|
||||||
|
for meta in meta_list {
|
||||||
|
match meta {
|
||||||
|
Meta::NameValue(nv) => {
|
||||||
|
let ident = nv.path.get_ident().unwrap().to_string();
|
||||||
|
if let Expr::Lit(ExprLit {
|
||||||
|
lit: Lit::Str(lit_str),
|
||||||
|
..
|
||||||
|
}) = nv.value
|
||||||
|
{
|
||||||
|
if ident.as_str() == "description" {
|
||||||
|
description = Some(lit_str.value());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Meta::List(list) if list.path.is_ident("params") => {
|
||||||
|
let nested: Punctuated<Meta, Token![,]> =
|
||||||
|
list.parse_args_with(Punctuated::parse_terminated)?;
|
||||||
|
|
||||||
|
for meta in nested {
|
||||||
|
if let Meta::NameValue(nv) = meta {
|
||||||
|
if let Expr::Lit(ExprLit {
|
||||||
|
lit: Lit::Str(lit_str),
|
||||||
|
..
|
||||||
|
}) = nv.value
|
||||||
|
{
|
||||||
|
let param_name = nv.path.get_ident().unwrap().to_string();
|
||||||
|
param_descriptions.insert(param_name, lit_str.value());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(MacroArgs {
|
||||||
|
description,
|
||||||
|
param_descriptions,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_json_type(ty: &Type) -> proc_macro2::TokenStream {
|
||||||
|
match ty {
|
||||||
|
Type::Path(type_path) => {
|
||||||
|
let segment = &type_path.path.segments[0];
|
||||||
|
let type_name = segment.ident.to_string();
|
||||||
|
|
||||||
|
// Handle Vec types
|
||||||
|
if type_name == "Vec" {
|
||||||
|
if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
|
||||||
|
if let syn::GenericArgument::Type(inner_type) = &args.args[0] {
|
||||||
|
let inner_json_type = get_json_type(inner_type);
|
||||||
|
return quote! {
|
||||||
|
"type": "array",
|
||||||
|
"items": { #inner_json_type }
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return quote! { "type": "array" };
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle primitive types
|
||||||
|
match type_name.as_str() {
|
||||||
|
"i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "f32" | "f64" => {
|
||||||
|
quote! { "type": "number" }
|
||||||
|
}
|
||||||
|
"String" | "str" => {
|
||||||
|
quote! { "type": "string" }
|
||||||
|
}
|
||||||
|
"bool" => {
|
||||||
|
quote! { "type": "boolean" }
|
||||||
|
}
|
||||||
|
// Handle other types as objects
|
||||||
|
_ => {
|
||||||
|
quote! { "type": "object" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
quote! { "type": "object" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A procedural macro that transforms a function into a `rig::tool::Tool` that can be used with a `rig::agent::Agent`.
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// Basic usage:
|
||||||
|
/// ```rust
|
||||||
|
/// use rig_derive::rig_tool;
|
||||||
|
///
|
||||||
|
/// #[rig_tool]
|
||||||
|
/// fn add(a: i32, b: i32) -> Result<i32, rig::tool::ToolError> {
|
||||||
|
/// Ok(a + b)
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// With description:
|
||||||
|
/// ```rust
|
||||||
|
/// use rig_derive::rig_tool;
|
||||||
|
///
|
||||||
|
/// #[rig_tool(description = "Perform basic arithmetic operations")]
|
||||||
|
/// fn calculator(x: i32, y: i32, operation: String) -> Result<i32, rig::tool::ToolError> {
|
||||||
|
/// match operation.as_str() {
|
||||||
|
/// "add" => Ok(x + y),
|
||||||
|
/// "subtract" => Ok(x - y),
|
||||||
|
/// "multiply" => Ok(x * y),
|
||||||
|
/// "divide" => Ok(x / y),
|
||||||
|
/// _ => Err(rig::tool::ToolError::ToolCallError("Unknown operation".into())),
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// With parameter descriptions:
|
||||||
|
/// ```rust
|
||||||
|
/// use rig_derive::rig_tool;
|
||||||
|
///
|
||||||
|
/// #[rig_tool(
|
||||||
|
/// description = "A tool that performs string operations",
|
||||||
|
/// params(
|
||||||
|
/// text = "The input text to process",
|
||||||
|
/// operation = "The operation to perform (uppercase, lowercase, reverse)"
|
||||||
|
/// )
|
||||||
|
/// )]
|
||||||
|
/// fn string_processor(text: String, operation: String) -> Result<String, rig::tool::ToolError> {
|
||||||
|
/// match operation.as_str() {
|
||||||
|
/// "uppercase" => Ok(text.to_uppercase()),
|
||||||
|
/// "lowercase" => Ok(text.to_lowercase()),
|
||||||
|
/// "reverse" => Ok(text.chars().rev().collect()),
|
||||||
|
/// _ => Err(rig::tool::ToolError::ToolCallError("Unknown operation".into())),
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
#[proc_macro_attribute]
|
||||||
|
pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream {
|
||||||
|
let args = parse_macro_input!(args as MacroArgs);
|
||||||
|
let input_fn = parse_macro_input!(input as syn::ItemFn);
|
||||||
|
|
||||||
|
// Extract function details
|
||||||
|
let fn_name = &input_fn.sig.ident;
|
||||||
|
let fn_name_str = fn_name.to_string();
|
||||||
|
let is_async = input_fn.sig.asyncness.is_some();
|
||||||
|
|
||||||
|
// Extract return type and get Output and Error types from Result<T, E>
|
||||||
|
let return_type = &input_fn.sig.output;
|
||||||
|
let output_type = match return_type {
|
||||||
|
ReturnType::Type(_, ty) => {
|
||||||
|
if let Type::Path(type_path) = ty.deref() {
|
||||||
|
if let Some(last_segment) = type_path.path.segments.last() {
|
||||||
|
if last_segment.ident == "Result" {
|
||||||
|
if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
|
||||||
|
if args.args.len() == 2 {
|
||||||
|
let output = args.args.first().unwrap();
|
||||||
|
let error = args.args.last().unwrap();
|
||||||
|
|
||||||
|
// Convert the error type to a string for comparison
|
||||||
|
let error_str = quote!(#error).to_string().replace(" ", "");
|
||||||
|
if !error_str.contains("rig::tool::ToolError") {
|
||||||
|
panic!("Expected rig::tool::ToolError as second type parameter but found {}", error_str);
|
||||||
|
}
|
||||||
|
|
||||||
|
quote!(#output)
|
||||||
|
} else {
|
||||||
|
panic!("Expected Result with two type parameters");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
panic!("Expected angle bracketed type parameters for Result");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
panic!("Return type must be a Result");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
panic!("Invalid return type");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
panic!("Invalid return type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => panic!("Function must have a return type"),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Generate PascalCase struct name from the function name
|
||||||
|
let struct_name = format_ident!("{}", { fn_name_str.to_case(Case::Pascal) });
|
||||||
|
|
||||||
|
// Use provided description or generate a default one
|
||||||
|
let tool_description = match args.description {
|
||||||
|
Some(desc) => quote! { #desc.to_string() },
|
||||||
|
None => quote! { format!("Function to {}", Self::NAME) },
|
||||||
|
};
|
||||||
|
|
||||||
|
// Extract parameter names, types, and descriptions
|
||||||
|
let mut param_names = Vec::new();
|
||||||
|
let mut param_types = Vec::new();
|
||||||
|
let mut param_descriptions = Vec::new();
|
||||||
|
let mut json_types = Vec::new();
|
||||||
|
|
||||||
|
for arg in input_fn.sig.inputs.iter() {
|
||||||
|
if let syn::FnArg::Typed(pat_type) = arg {
|
||||||
|
if let syn::Pat::Ident(param_ident) = &*pat_type.pat {
|
||||||
|
let param_name = ¶m_ident.ident;
|
||||||
|
let param_name_str = param_name.to_string();
|
||||||
|
let ty = &pat_type.ty;
|
||||||
|
let default_parameter_description = format!("Parameter {}", param_name_str);
|
||||||
|
let description = args
|
||||||
|
.param_descriptions
|
||||||
|
.get(¶m_name_str)
|
||||||
|
.map(|s| s.to_owned())
|
||||||
|
.unwrap_or(default_parameter_description);
|
||||||
|
|
||||||
|
param_names.push(param_name);
|
||||||
|
param_types.push(ty);
|
||||||
|
param_descriptions.push(description);
|
||||||
|
json_types.push(get_json_type(ty));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let params_struct_name = format_ident!("{}Parameters", struct_name);
|
||||||
|
let static_name = format_ident!("{}", fn_name_str.to_uppercase());
|
||||||
|
|
||||||
|
// Generate the call implementation based on whether the function is async
|
||||||
|
let call_impl = if is_async {
|
||||||
|
quote! {
|
||||||
|
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||||
|
#fn_name(#(args.#param_names,)*).await.map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
quote! {
|
||||||
|
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||||
|
#fn_name(#(args.#param_names,)*).map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let expanded = quote! {
|
||||||
|
#[derive(serde::Deserialize)]
|
||||||
|
struct #params_struct_name {
|
||||||
|
#(#param_names: #param_types,)*
|
||||||
|
}
|
||||||
|
|
||||||
|
#input_fn
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub(crate) struct #struct_name;
|
||||||
|
|
||||||
|
impl rig::tool::Tool for #struct_name {
|
||||||
|
const NAME: &'static str = #fn_name_str;
|
||||||
|
|
||||||
|
type Args = #params_struct_name;
|
||||||
|
type Output = #output_type;
|
||||||
|
type Error = rig::tool::ToolError;
|
||||||
|
|
||||||
|
fn name(&self) -> String {
|
||||||
|
#fn_name_str.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn definition(&self, _prompt: String) -> rig::completion::ToolDefinition {
|
||||||
|
let parameters = serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
#(
|
||||||
|
stringify!(#param_names): {
|
||||||
|
#json_types,
|
||||||
|
"description": #param_descriptions
|
||||||
|
}
|
||||||
|
),*
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
rig::completion::ToolDefinition {
|
||||||
|
name: #fn_name_str.to_string(),
|
||||||
|
description: #tool_description.to_string(),
|
||||||
|
parameters,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#call_impl
|
||||||
|
}
|
||||||
|
|
||||||
|
pub static #static_name: #struct_name = #struct_name;
|
||||||
|
};
|
||||||
|
|
||||||
|
TokenStream::from(expanded)
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,145 @@
|
||||||
|
use rig::tool::Tool;
|
||||||
|
use rig_derive::rig_tool;
|
||||||
|
|
||||||
|
#[rig_tool(
|
||||||
|
description = "Perform basic arithmetic operations",
|
||||||
|
params(
|
||||||
|
x = "First number in the calculation",
|
||||||
|
y = "Second number in the calculation",
|
||||||
|
operation = "The operation to perform (add, subtract, multiply, divide)"
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
async fn calculator(x: i32, y: i32, operation: String) -> Result<i32, rig::tool::ToolError> {
|
||||||
|
match operation.as_str() {
|
||||||
|
"add" => Ok(x + y),
|
||||||
|
"subtract" => Ok(x - y),
|
||||||
|
"multiply" => Ok(x * y),
|
||||||
|
"divide" => {
|
||||||
|
if y == 0 {
|
||||||
|
Err(rig::tool::ToolError::ToolCallError(
|
||||||
|
"Division by zero".into(),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
Ok(x / y)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Err(rig::tool::ToolError::ToolCallError(
|
||||||
|
format!("Unknown operation: {}", operation).into(),
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rig_tool(
|
||||||
|
description = "Perform basic arithmetic operations",
|
||||||
|
params(
|
||||||
|
x = "First number in the calculation",
|
||||||
|
y = "Second number in the calculation",
|
||||||
|
operation = "The operation to perform (add, subtract, multiply, divide)"
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
fn sync_calculator(x: i32, y: i32, operation: String) -> Result<i32, rig::tool::ToolError> {
|
||||||
|
match operation.as_str() {
|
||||||
|
"add" => Ok(x + y),
|
||||||
|
"subtract" => Ok(x - y),
|
||||||
|
"multiply" => Ok(x * y),
|
||||||
|
"divide" => {
|
||||||
|
if y == 0 {
|
||||||
|
Err(rig::tool::ToolError::ToolCallError(
|
||||||
|
"Division by zero".into(),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
Ok(x / y)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Err(rig::tool::ToolError::ToolCallError(
|
||||||
|
format!("Unknown operation: {}", operation).into(),
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_calculator_tool() {
|
||||||
|
// Create an instance of our tool
|
||||||
|
let calculator = Calculator::default();
|
||||||
|
|
||||||
|
// Test tool information
|
||||||
|
let definition = calculator.definition(String::default()).await;
|
||||||
|
println!("{:?}", definition);
|
||||||
|
assert_eq!(calculator.name(), "calculator");
|
||||||
|
assert_eq!(
|
||||||
|
definition.description,
|
||||||
|
"Perform basic arithmetic operations"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Test valid operations
|
||||||
|
let test_cases = vec![
|
||||||
|
(
|
||||||
|
CalculatorParameters {
|
||||||
|
x: 5,
|
||||||
|
y: 3,
|
||||||
|
operation: "add".to_string(),
|
||||||
|
},
|
||||||
|
8,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
CalculatorParameters {
|
||||||
|
x: 5,
|
||||||
|
y: 3,
|
||||||
|
operation: "subtract".to_string(),
|
||||||
|
},
|
||||||
|
2,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
CalculatorParameters {
|
||||||
|
x: 5,
|
||||||
|
y: 3,
|
||||||
|
operation: "multiply".to_string(),
|
||||||
|
},
|
||||||
|
15,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
CalculatorParameters {
|
||||||
|
x: 6,
|
||||||
|
y: 2,
|
||||||
|
operation: "divide".to_string(),
|
||||||
|
},
|
||||||
|
3,
|
||||||
|
),
|
||||||
|
];
|
||||||
|
|
||||||
|
for (input, expected) in test_cases {
|
||||||
|
let result = calculator.call(input).await.unwrap();
|
||||||
|
assert_eq!(result, serde_json::json!(expected));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test division by zero
|
||||||
|
let div_zero = CalculatorParameters {
|
||||||
|
x: 5,
|
||||||
|
y: 0,
|
||||||
|
operation: "divide".to_string(),
|
||||||
|
};
|
||||||
|
let err = calculator.call(div_zero).await.unwrap_err();
|
||||||
|
assert!(matches!(err, rig::tool::ToolError::ToolCallError(_)));
|
||||||
|
|
||||||
|
// Test invalid operation
|
||||||
|
let invalid_op = CalculatorParameters {
|
||||||
|
x: 5,
|
||||||
|
y: 3,
|
||||||
|
operation: "power".to_string(),
|
||||||
|
};
|
||||||
|
let err = calculator.call(invalid_op).await.unwrap_err();
|
||||||
|
assert!(matches!(err, rig::tool::ToolError::ToolCallError(_)));
|
||||||
|
|
||||||
|
// Test sync calculator
|
||||||
|
let sync_calculator = SyncCalculator::default();
|
||||||
|
let result = sync_calculator
|
||||||
|
.call(SyncCalculatorParameters {
|
||||||
|
x: 5,
|
||||||
|
y: 3,
|
||||||
|
operation: "add".to_string(),
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(result, serde_json::json!(8));
|
||||||
|
}
|
Loading…
Reference in New Issue