mirror of https://github.com/0xplaygrounds/rig
feat: Add `rig_tool` macro (#353)
* feat: rig-macros crates providing a rig_tool proc macro * Support non-async functions as tools * Add static instance of generated tool * Make descriptions optional, add static instance of tool, support sync functions. * Add examples * Fixing Cargo.toml merge issues * Fix examples, replace use of schemars * Remove excessive comments and fix broken test * Move rig-macros contents into rig-core-derive * Doc comment + examples * Make doc linter happy * Simplify usage of serde_json, remove re-export
This commit is contained in:
parent
d10d1cc73b
commit
c7d4851e32
|
@ -8699,10 +8699,16 @@ dependencies = [
|
|||
name = "rig-derive"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"convert_case 0.6.0",
|
||||
"indoc",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rig-core 0.11.0",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"syn 2.0.100",
|
||||
"tokio",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
@ -7,10 +7,35 @@ description = "Internal crate that implements Rig derive macros."
|
|||
repository = "https://github.com/0xPlaygrounds/rig"
|
||||
|
||||
[dependencies]
|
||||
convert_case = { version = "0.6.0" }
|
||||
indoc = "2.0.5"
|
||||
proc-macro2 = { version = "1.0.87", features = ["proc-macro"] }
|
||||
quote = "1.0.37"
|
||||
serde_json = "1.0.108"
|
||||
syn = { version = "2.0.79", features = ["full"]}
|
||||
|
||||
[lib]
|
||||
proc-macro = true
|
||||
|
||||
[dev-dependencies]
|
||||
rig-core = { path = "../../rig-core" }
|
||||
serde = "1.0"
|
||||
serde_json = "1.0.108"
|
||||
tokio = { version = "1.44.0", features = ["full"] }
|
||||
tracing-subscriber = "0.3.0"
|
||||
|
||||
[[example]]
|
||||
name = "simple"
|
||||
path = "examples/rig_tool/simple.rs"
|
||||
|
||||
[[example]]
|
||||
name = "with_description"
|
||||
path = "examples/rig_tool/with_description.rs"
|
||||
|
||||
[[example]]
|
||||
name = "full"
|
||||
path = "examples/rig_tool/full.rs"
|
||||
|
||||
[[example]]
|
||||
name = "async_tool"
|
||||
path = "examples/rig_tool/async_tool.rs"
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
use rig::completion::Prompt;
|
||||
use rig::providers;
|
||||
use rig::tool::Tool;
|
||||
use rig_derive::rig_tool;
|
||||
use std::time::Duration;
|
||||
use tracing_subscriber;
|
||||
|
||||
// Example demonstrating async tool usage
|
||||
#[rig_tool(
|
||||
description = "A tool that simulates an async operation",
|
||||
params(
|
||||
input = "Input value to process",
|
||||
delay_ms = "Delay in milliseconds before returning result"
|
||||
)
|
||||
)]
|
||||
async fn async_operation(input: String, delay_ms: u64) -> Result<String, rig::tool::ToolError> {
|
||||
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
|
||||
|
||||
Ok(format!(
|
||||
"Processed after {}ms: {}",
|
||||
delay_ms,
|
||||
input.to_uppercase()
|
||||
))
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
tracing_subscriber::fmt().pretty().init();
|
||||
|
||||
let async_agent = providers::openai::Client::from_env()
|
||||
.agent(providers::openai::GPT_4O)
|
||||
.preamble("You are an agent with tools access, always use the tools")
|
||||
.max_tokens(1024)
|
||||
.tool(AsyncOperation)
|
||||
.build();
|
||||
|
||||
println!("Tool definition:");
|
||||
println!(
|
||||
"ASYNCOPERATION: {}",
|
||||
serde_json::to_string_pretty(&AsyncOperation.definition(String::default()).await).unwrap()
|
||||
);
|
||||
|
||||
for prompt in [
|
||||
"What tools do you have?",
|
||||
"Process the text 'hello world' with a delay of 1000ms",
|
||||
"Process the text 'async operation' with a delay of 500ms",
|
||||
"Process the text 'concurrent calls' with a delay of 200ms",
|
||||
"Process the text 'error handling' with a delay of 'not a number'",
|
||||
] {
|
||||
println!("User: {}", prompt);
|
||||
println!("Agent: {}", async_agent.prompt(prompt).await.unwrap());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,58 @@
|
|||
use rig::completion::Prompt;
|
||||
use rig::providers;
|
||||
use rig::tool::Tool;
|
||||
use rig_derive::rig_tool;
|
||||
use tracing_subscriber;
|
||||
|
||||
// Example with full attributes including parameter descriptions
|
||||
#[rig_tool(
|
||||
description = "A tool that performs string operations",
|
||||
params(
|
||||
text = "The input text to process",
|
||||
operation = "The operation to perform (uppercase, lowercase, reverse)",
|
||||
)
|
||||
)]
|
||||
fn string_processor(text: String, operation: String) -> Result<String, rig::tool::ToolError> {
|
||||
let result = match operation.as_str() {
|
||||
"uppercase" => text.to_uppercase(),
|
||||
"lowercase" => text.to_lowercase(),
|
||||
"reverse" => text.chars().rev().collect(),
|
||||
_ => {
|
||||
return Err(rig::tool::ToolError::ToolCallError(
|
||||
format!("Unknown operation: {}", operation).into(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
tracing_subscriber::fmt().pretty().init();
|
||||
|
||||
let string_agent = providers::openai::Client::from_env()
|
||||
.agent(providers::openai::GPT_4O)
|
||||
.preamble("You are an agent with tools access, always use the tools")
|
||||
.max_tokens(1024)
|
||||
.tool(StringProcessor)
|
||||
.build();
|
||||
|
||||
println!("Tool definition:");
|
||||
println!(
|
||||
"STRINGPROCESSOR: {}",
|
||||
serde_json::to_string_pretty(&StringProcessor.definition(String::default()).await).unwrap()
|
||||
);
|
||||
|
||||
for prompt in [
|
||||
"What tools do you have?",
|
||||
"Convert 'hello world' to uppercase",
|
||||
"Convert 'HELLO WORLD' to lowercase",
|
||||
"Reverse the string 'hello world'",
|
||||
"Convert 'hello world' to uppercase and repeat it 3 times",
|
||||
"Perform an invalid operation on 'hello world'",
|
||||
] {
|
||||
println!("User: {}", prompt);
|
||||
println!("Agent: {}", string_agent.prompt(prompt).await.unwrap());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,71 @@
|
|||
use rig::completion::Prompt;
|
||||
use rig::providers;
|
||||
use rig_derive::rig_tool;
|
||||
use tracing_subscriber;
|
||||
|
||||
// Simple example with no attributes
|
||||
#[rig_tool]
|
||||
fn add(a: i32, b: i32) -> Result<i32, rig::tool::ToolError> {
|
||||
Ok(a + b)
|
||||
}
|
||||
|
||||
#[rig_tool]
|
||||
fn subtract(a: i32, b: i32) -> Result<i32, rig::tool::ToolError> {
|
||||
Ok(a - b)
|
||||
}
|
||||
|
||||
#[rig_tool]
|
||||
fn multiply(a: i32, b: i32) -> Result<i32, rig::tool::ToolError> {
|
||||
Ok(a * b)
|
||||
}
|
||||
|
||||
#[rig_tool]
|
||||
fn divide(a: i32, b: i32) -> Result<i32, rig::tool::ToolError> {
|
||||
if b == 0 {
|
||||
Err(rig::tool::ToolError::ToolCallError(
|
||||
"Division by zero".into(),
|
||||
))
|
||||
} else {
|
||||
Ok(a / b)
|
||||
}
|
||||
}
|
||||
|
||||
#[rig_tool]
|
||||
fn answer_secret_question() -> Result<(bool, bool, bool, bool, bool), rig::tool::ToolError> {
|
||||
Ok((false, false, true, false, false))
|
||||
}
|
||||
|
||||
#[rig_tool]
|
||||
fn how_many_rs(s: String) -> Result<usize, rig::tool::ToolError> {
|
||||
Ok(s.chars()
|
||||
.filter(|c| *c == 'r' || *c == 'R')
|
||||
.collect::<Vec<_>>()
|
||||
.len())
|
||||
}
|
||||
|
||||
#[rig_tool]
|
||||
fn sum_numbers(numbers: Vec<i64>) -> Result<i64, rig::tool::ToolError> {
|
||||
Ok(numbers.iter().sum())
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
tracing_subscriber::fmt().pretty().init();
|
||||
|
||||
let calculator_agent = providers::openai::Client::from_env()
|
||||
.agent(providers::openai::GPT_4O)
|
||||
.preamble("You are an agent with tools access, always use the tools")
|
||||
.max_tokens(1024)
|
||||
.tool(Add)
|
||||
.build();
|
||||
|
||||
for prompt in [
|
||||
"What tools do you have?",
|
||||
"Calculate 5 + 3",
|
||||
"What is 10 + 20?",
|
||||
"Add 100 and 200",
|
||||
] {
|
||||
println!("User: {}", prompt);
|
||||
println!("Agent: {}", calculator_agent.prompt(prompt).await.unwrap());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
use rig::completion::Prompt;
|
||||
use rig::providers;
|
||||
use rig::tool::Tool;
|
||||
use rig_derive::rig_tool;
|
||||
use tracing_subscriber;
|
||||
|
||||
// Example with description attribute
|
||||
#[rig_tool(description = "Perform basic arithmetic operations")]
|
||||
fn calculator(x: i32, y: i32, operation: String) -> Result<i32, rig::tool::ToolError> {
|
||||
match operation.as_str() {
|
||||
"add" => Ok(x + y),
|
||||
"subtract" => Ok(x - y),
|
||||
"multiply" => Ok(x * y),
|
||||
"divide" => {
|
||||
if y == 0 {
|
||||
Err(rig::tool::ToolError::ToolCallError(
|
||||
"Division by zero".into(),
|
||||
))
|
||||
} else {
|
||||
Ok(x / y)
|
||||
}
|
||||
}
|
||||
_ => Err(rig::tool::ToolError::ToolCallError(
|
||||
format!("Unknown operation: {}", operation).into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
tracing_subscriber::fmt().pretty().init();
|
||||
|
||||
let calculator_agent = providers::openai::Client::from_env()
|
||||
.agent(providers::openai::GPT_4O)
|
||||
.preamble("You are an agent with tools access, always use the tools")
|
||||
.max_tokens(1024)
|
||||
.tool(Calculator)
|
||||
.build();
|
||||
|
||||
println!("Tool definition:");
|
||||
println!(
|
||||
"CALCULATOR: {}",
|
||||
serde_json::to_string_pretty(&CALCULATOR.definition(String::default()).await).unwrap()
|
||||
);
|
||||
|
||||
for prompt in [
|
||||
"What tools do you have?",
|
||||
"Calculate 5 + 3",
|
||||
"What is 10 - 4?",
|
||||
"Multiply 6 and 7",
|
||||
"Divide 20 by 5",
|
||||
"What is 10 / 0?",
|
||||
] {
|
||||
println!("User: {}", prompt);
|
||||
println!("Agent: {}", calculator_agent.prompt(prompt).await.unwrap());
|
||||
}
|
||||
}
|
|
@ -1,6 +1,15 @@
|
|||
extern crate proc_macro;
|
||||
|
||||
use convert_case::{Case, Casing};
|
||||
use proc_macro::TokenStream;
|
||||
use syn::{parse_macro_input, DeriveInput};
|
||||
use quote::{format_ident, quote};
|
||||
use std::{collections::HashMap, ops::Deref};
|
||||
use syn::{
|
||||
parse::{Parse, ParseStream},
|
||||
parse_macro_input,
|
||||
punctuated::Punctuated,
|
||||
DeriveInput, Expr, ExprLit, Lit, Meta, PathArguments, ReturnType, Token, Type,
|
||||
};
|
||||
|
||||
mod basic;
|
||||
mod custom;
|
||||
|
@ -19,3 +28,311 @@ pub fn derive_embedding_trait(item: TokenStream) -> TokenStream {
|
|||
.unwrap_or_else(syn::Error::into_compile_error)
|
||||
.into()
|
||||
}
|
||||
|
||||
struct MacroArgs {
|
||||
description: Option<String>,
|
||||
param_descriptions: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl Parse for MacroArgs {
|
||||
fn parse(input: ParseStream) -> syn::Result<Self> {
|
||||
let mut description = None;
|
||||
let mut param_descriptions = HashMap::new();
|
||||
|
||||
// If the input is empty, return default values
|
||||
if input.is_empty() {
|
||||
return Ok(MacroArgs {
|
||||
description,
|
||||
param_descriptions,
|
||||
});
|
||||
}
|
||||
|
||||
let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input)?;
|
||||
|
||||
for meta in meta_list {
|
||||
match meta {
|
||||
Meta::NameValue(nv) => {
|
||||
let ident = nv.path.get_ident().unwrap().to_string();
|
||||
if let Expr::Lit(ExprLit {
|
||||
lit: Lit::Str(lit_str),
|
||||
..
|
||||
}) = nv.value
|
||||
{
|
||||
if ident.as_str() == "description" {
|
||||
description = Some(lit_str.value());
|
||||
}
|
||||
}
|
||||
}
|
||||
Meta::List(list) if list.path.is_ident("params") => {
|
||||
let nested: Punctuated<Meta, Token![,]> =
|
||||
list.parse_args_with(Punctuated::parse_terminated)?;
|
||||
|
||||
for meta in nested {
|
||||
if let Meta::NameValue(nv) = meta {
|
||||
if let Expr::Lit(ExprLit {
|
||||
lit: Lit::Str(lit_str),
|
||||
..
|
||||
}) = nv.value
|
||||
{
|
||||
let param_name = nv.path.get_ident().unwrap().to_string();
|
||||
param_descriptions.insert(param_name, lit_str.value());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(MacroArgs {
|
||||
description,
|
||||
param_descriptions,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn get_json_type(ty: &Type) -> proc_macro2::TokenStream {
|
||||
match ty {
|
||||
Type::Path(type_path) => {
|
||||
let segment = &type_path.path.segments[0];
|
||||
let type_name = segment.ident.to_string();
|
||||
|
||||
// Handle Vec types
|
||||
if type_name == "Vec" {
|
||||
if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
|
||||
if let syn::GenericArgument::Type(inner_type) = &args.args[0] {
|
||||
let inner_json_type = get_json_type(inner_type);
|
||||
return quote! {
|
||||
"type": "array",
|
||||
"items": { #inner_json_type }
|
||||
};
|
||||
}
|
||||
}
|
||||
return quote! { "type": "array" };
|
||||
}
|
||||
|
||||
// Handle primitive types
|
||||
match type_name.as_str() {
|
||||
"i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "f32" | "f64" => {
|
||||
quote! { "type": "number" }
|
||||
}
|
||||
"String" | "str" => {
|
||||
quote! { "type": "string" }
|
||||
}
|
||||
"bool" => {
|
||||
quote! { "type": "boolean" }
|
||||
}
|
||||
// Handle other types as objects
|
||||
_ => {
|
||||
quote! { "type": "object" }
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
quote! { "type": "object" }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A procedural macro that transforms a function into a `rig::tool::Tool` that can be used with a `rig::agent::Agent`.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// Basic usage:
|
||||
/// ```rust
|
||||
/// use rig_derive::rig_tool;
|
||||
///
|
||||
/// #[rig_tool]
|
||||
/// fn add(a: i32, b: i32) -> Result<i32, rig::tool::ToolError> {
|
||||
/// Ok(a + b)
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// With description:
|
||||
/// ```rust
|
||||
/// use rig_derive::rig_tool;
|
||||
///
|
||||
/// #[rig_tool(description = "Perform basic arithmetic operations")]
|
||||
/// fn calculator(x: i32, y: i32, operation: String) -> Result<i32, rig::tool::ToolError> {
|
||||
/// match operation.as_str() {
|
||||
/// "add" => Ok(x + y),
|
||||
/// "subtract" => Ok(x - y),
|
||||
/// "multiply" => Ok(x * y),
|
||||
/// "divide" => Ok(x / y),
|
||||
/// _ => Err(rig::tool::ToolError::ToolCallError("Unknown operation".into())),
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// With parameter descriptions:
|
||||
/// ```rust
|
||||
/// use rig_derive::rig_tool;
|
||||
///
|
||||
/// #[rig_tool(
|
||||
/// description = "A tool that performs string operations",
|
||||
/// params(
|
||||
/// text = "The input text to process",
|
||||
/// operation = "The operation to perform (uppercase, lowercase, reverse)"
|
||||
/// )
|
||||
/// )]
|
||||
/// fn string_processor(text: String, operation: String) -> Result<String, rig::tool::ToolError> {
|
||||
/// match operation.as_str() {
|
||||
/// "uppercase" => Ok(text.to_uppercase()),
|
||||
/// "lowercase" => Ok(text.to_lowercase()),
|
||||
/// "reverse" => Ok(text.chars().rev().collect()),
|
||||
/// _ => Err(rig::tool::ToolError::ToolCallError("Unknown operation".into())),
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
#[proc_macro_attribute]
|
||||
pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream {
|
||||
let args = parse_macro_input!(args as MacroArgs);
|
||||
let input_fn = parse_macro_input!(input as syn::ItemFn);
|
||||
|
||||
// Extract function details
|
||||
let fn_name = &input_fn.sig.ident;
|
||||
let fn_name_str = fn_name.to_string();
|
||||
let is_async = input_fn.sig.asyncness.is_some();
|
||||
|
||||
// Extract return type and get Output and Error types from Result<T, E>
|
||||
let return_type = &input_fn.sig.output;
|
||||
let output_type = match return_type {
|
||||
ReturnType::Type(_, ty) => {
|
||||
if let Type::Path(type_path) = ty.deref() {
|
||||
if let Some(last_segment) = type_path.path.segments.last() {
|
||||
if last_segment.ident == "Result" {
|
||||
if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
|
||||
if args.args.len() == 2 {
|
||||
let output = args.args.first().unwrap();
|
||||
let error = args.args.last().unwrap();
|
||||
|
||||
// Convert the error type to a string for comparison
|
||||
let error_str = quote!(#error).to_string().replace(" ", "");
|
||||
if !error_str.contains("rig::tool::ToolError") {
|
||||
panic!("Expected rig::tool::ToolError as second type parameter but found {}", error_str);
|
||||
}
|
||||
|
||||
quote!(#output)
|
||||
} else {
|
||||
panic!("Expected Result with two type parameters");
|
||||
}
|
||||
} else {
|
||||
panic!("Expected angle bracketed type parameters for Result");
|
||||
}
|
||||
} else {
|
||||
panic!("Return type must be a Result");
|
||||
}
|
||||
} else {
|
||||
panic!("Invalid return type");
|
||||
}
|
||||
} else {
|
||||
panic!("Invalid return type");
|
||||
}
|
||||
}
|
||||
_ => panic!("Function must have a return type"),
|
||||
};
|
||||
|
||||
// Generate PascalCase struct name from the function name
|
||||
let struct_name = format_ident!("{}", { fn_name_str.to_case(Case::Pascal) });
|
||||
|
||||
// Use provided description or generate a default one
|
||||
let tool_description = match args.description {
|
||||
Some(desc) => quote! { #desc.to_string() },
|
||||
None => quote! { format!("Function to {}", Self::NAME) },
|
||||
};
|
||||
|
||||
// Extract parameter names, types, and descriptions
|
||||
let mut param_names = Vec::new();
|
||||
let mut param_types = Vec::new();
|
||||
let mut param_descriptions = Vec::new();
|
||||
let mut json_types = Vec::new();
|
||||
|
||||
for arg in input_fn.sig.inputs.iter() {
|
||||
if let syn::FnArg::Typed(pat_type) = arg {
|
||||
if let syn::Pat::Ident(param_ident) = &*pat_type.pat {
|
||||
let param_name = ¶m_ident.ident;
|
||||
let param_name_str = param_name.to_string();
|
||||
let ty = &pat_type.ty;
|
||||
let default_parameter_description = format!("Parameter {}", param_name_str);
|
||||
let description = args
|
||||
.param_descriptions
|
||||
.get(¶m_name_str)
|
||||
.map(|s| s.to_owned())
|
||||
.unwrap_or(default_parameter_description);
|
||||
|
||||
param_names.push(param_name);
|
||||
param_types.push(ty);
|
||||
param_descriptions.push(description);
|
||||
json_types.push(get_json_type(ty));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let params_struct_name = format_ident!("{}Parameters", struct_name);
|
||||
let static_name = format_ident!("{}", fn_name_str.to_uppercase());
|
||||
|
||||
// Generate the call implementation based on whether the function is async
|
||||
let call_impl = if is_async {
|
||||
quote! {
|
||||
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||
#fn_name(#(args.#param_names,)*).await.map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||
#fn_name(#(args.#param_names,)*).map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let expanded = quote! {
|
||||
#[derive(serde::Deserialize)]
|
||||
struct #params_struct_name {
|
||||
#(#param_names: #param_types,)*
|
||||
}
|
||||
|
||||
#input_fn
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct #struct_name;
|
||||
|
||||
impl rig::tool::Tool for #struct_name {
|
||||
const NAME: &'static str = #fn_name_str;
|
||||
|
||||
type Args = #params_struct_name;
|
||||
type Output = #output_type;
|
||||
type Error = rig::tool::ToolError;
|
||||
|
||||
fn name(&self) -> String {
|
||||
#fn_name_str.to_string()
|
||||
}
|
||||
|
||||
async fn definition(&self, _prompt: String) -> rig::completion::ToolDefinition {
|
||||
let parameters = serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
#(
|
||||
stringify!(#param_names): {
|
||||
#json_types,
|
||||
"description": #param_descriptions
|
||||
}
|
||||
),*
|
||||
}
|
||||
});
|
||||
|
||||
rig::completion::ToolDefinition {
|
||||
name: #fn_name_str.to_string(),
|
||||
description: #tool_description.to_string(),
|
||||
parameters,
|
||||
}
|
||||
}
|
||||
|
||||
#call_impl
|
||||
}
|
||||
|
||||
pub static #static_name: #struct_name = #struct_name;
|
||||
};
|
||||
|
||||
TokenStream::from(expanded)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,145 @@
|
|||
use rig::tool::Tool;
|
||||
use rig_derive::rig_tool;
|
||||
|
||||
#[rig_tool(
|
||||
description = "Perform basic arithmetic operations",
|
||||
params(
|
||||
x = "First number in the calculation",
|
||||
y = "Second number in the calculation",
|
||||
operation = "The operation to perform (add, subtract, multiply, divide)"
|
||||
)
|
||||
)]
|
||||
async fn calculator(x: i32, y: i32, operation: String) -> Result<i32, rig::tool::ToolError> {
|
||||
match operation.as_str() {
|
||||
"add" => Ok(x + y),
|
||||
"subtract" => Ok(x - y),
|
||||
"multiply" => Ok(x * y),
|
||||
"divide" => {
|
||||
if y == 0 {
|
||||
Err(rig::tool::ToolError::ToolCallError(
|
||||
"Division by zero".into(),
|
||||
))
|
||||
} else {
|
||||
Ok(x / y)
|
||||
}
|
||||
}
|
||||
_ => Err(rig::tool::ToolError::ToolCallError(
|
||||
format!("Unknown operation: {}", operation).into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
#[rig_tool(
|
||||
description = "Perform basic arithmetic operations",
|
||||
params(
|
||||
x = "First number in the calculation",
|
||||
y = "Second number in the calculation",
|
||||
operation = "The operation to perform (add, subtract, multiply, divide)"
|
||||
)
|
||||
)]
|
||||
fn sync_calculator(x: i32, y: i32, operation: String) -> Result<i32, rig::tool::ToolError> {
|
||||
match operation.as_str() {
|
||||
"add" => Ok(x + y),
|
||||
"subtract" => Ok(x - y),
|
||||
"multiply" => Ok(x * y),
|
||||
"divide" => {
|
||||
if y == 0 {
|
||||
Err(rig::tool::ToolError::ToolCallError(
|
||||
"Division by zero".into(),
|
||||
))
|
||||
} else {
|
||||
Ok(x / y)
|
||||
}
|
||||
}
|
||||
_ => Err(rig::tool::ToolError::ToolCallError(
|
||||
format!("Unknown operation: {}", operation).into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_calculator_tool() {
|
||||
// Create an instance of our tool
|
||||
let calculator = Calculator::default();
|
||||
|
||||
// Test tool information
|
||||
let definition = calculator.definition(String::default()).await;
|
||||
println!("{:?}", definition);
|
||||
assert_eq!(calculator.name(), "calculator");
|
||||
assert_eq!(
|
||||
definition.description,
|
||||
"Perform basic arithmetic operations"
|
||||
);
|
||||
|
||||
// Test valid operations
|
||||
let test_cases = vec![
|
||||
(
|
||||
CalculatorParameters {
|
||||
x: 5,
|
||||
y: 3,
|
||||
operation: "add".to_string(),
|
||||
},
|
||||
8,
|
||||
),
|
||||
(
|
||||
CalculatorParameters {
|
||||
x: 5,
|
||||
y: 3,
|
||||
operation: "subtract".to_string(),
|
||||
},
|
||||
2,
|
||||
),
|
||||
(
|
||||
CalculatorParameters {
|
||||
x: 5,
|
||||
y: 3,
|
||||
operation: "multiply".to_string(),
|
||||
},
|
||||
15,
|
||||
),
|
||||
(
|
||||
CalculatorParameters {
|
||||
x: 6,
|
||||
y: 2,
|
||||
operation: "divide".to_string(),
|
||||
},
|
||||
3,
|
||||
),
|
||||
];
|
||||
|
||||
for (input, expected) in test_cases {
|
||||
let result = calculator.call(input).await.unwrap();
|
||||
assert_eq!(result, serde_json::json!(expected));
|
||||
}
|
||||
|
||||
// Test division by zero
|
||||
let div_zero = CalculatorParameters {
|
||||
x: 5,
|
||||
y: 0,
|
||||
operation: "divide".to_string(),
|
||||
};
|
||||
let err = calculator.call(div_zero).await.unwrap_err();
|
||||
assert!(matches!(err, rig::tool::ToolError::ToolCallError(_)));
|
||||
|
||||
// Test invalid operation
|
||||
let invalid_op = CalculatorParameters {
|
||||
x: 5,
|
||||
y: 3,
|
||||
operation: "power".to_string(),
|
||||
};
|
||||
let err = calculator.call(invalid_op).await.unwrap_err();
|
||||
assert!(matches!(err, rig::tool::ToolError::ToolCallError(_)));
|
||||
|
||||
// Test sync calculator
|
||||
let sync_calculator = SyncCalculator::default();
|
||||
let result = sync_calculator
|
||||
.call(SyncCalculatorParameters {
|
||||
x: 5,
|
||||
y: 3,
|
||||
operation: "add".to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result, serde_json::json!(8));
|
||||
}
|
Loading…
Reference in New Issue