diff --git a/Cargo.lock b/Cargo.lock index d647af4..5a7d564 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8699,10 +8699,16 @@ dependencies = [ name = "rig-derive" version = "0.1.0" dependencies = [ + "convert_case 0.6.0", "indoc", "proc-macro2", "quote", + "rig-core 0.11.0", + "serde", + "serde_json", "syn 2.0.100", + "tokio", + "tracing-subscriber", ] [[package]] diff --git a/rig-core/rig-core-derive/Cargo.toml b/rig-core/rig-core-derive/Cargo.toml index 890bf53..d5228ae 100644 --- a/rig-core/rig-core-derive/Cargo.toml +++ b/rig-core/rig-core-derive/Cargo.toml @@ -7,10 +7,35 @@ description = "Internal crate that implements Rig derive macros." repository = "https://github.com/0xPlaygrounds/rig" [dependencies] +convert_case = { version = "0.6.0" } indoc = "2.0.5" proc-macro2 = { version = "1.0.87", features = ["proc-macro"] } quote = "1.0.37" +serde_json = "1.0.108" syn = { version = "2.0.79", features = ["full"]} [lib] proc-macro = true + +[dev-dependencies] +rig-core = { path = "../../rig-core" } +serde = "1.0" +serde_json = "1.0.108" +tokio = { version = "1.44.0", features = ["full"] } +tracing-subscriber = "0.3.0" + +[[example]] +name = "simple" +path = "examples/rig_tool/simple.rs" + +[[example]] +name = "with_description" +path = "examples/rig_tool/with_description.rs" + +[[example]] +name = "full" +path = "examples/rig_tool/full.rs" + +[[example]] +name = "async_tool" +path = "examples/rig_tool/async_tool.rs" diff --git a/rig-core/rig-core-derive/examples/rig_tool/async_tool.rs b/rig-core/rig-core-derive/examples/rig_tool/async_tool.rs new file mode 100644 index 0000000..bc2fc5d --- /dev/null +++ b/rig-core/rig-core-derive/examples/rig_tool/async_tool.rs @@ -0,0 +1,53 @@ +use rig::completion::Prompt; +use rig::providers; +use rig::tool::Tool; +use rig_derive::rig_tool; +use std::time::Duration; +use tracing_subscriber; + +// Example demonstrating async tool usage +#[rig_tool( + description = "A tool that simulates an async operation", + params( + input = 
"Input value to process", + delay_ms = "Delay in milliseconds before returning result" + ) +)] +async fn async_operation(input: String, delay_ms: u64) -> Result { + tokio::time::sleep(Duration::from_millis(delay_ms)).await; + + Ok(format!( + "Processed after {}ms: {}", + delay_ms, + input.to_uppercase() + )) +} + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt().pretty().init(); + + let async_agent = providers::openai::Client::from_env() + .agent(providers::openai::GPT_4O) + .preamble("You are an agent with tools access, always use the tools") + .max_tokens(1024) + .tool(AsyncOperation) + .build(); + + println!("Tool definition:"); + println!( + "ASYNCOPERATION: {}", + serde_json::to_string_pretty(&AsyncOperation.definition(String::default()).await).unwrap() + ); + + for prompt in [ + "What tools do you have?", + "Process the text 'hello world' with a delay of 1000ms", + "Process the text 'async operation' with a delay of 500ms", + "Process the text 'concurrent calls' with a delay of 200ms", + "Process the text 'error handling' with a delay of 'not a number'", + ] { + println!("User: {}", prompt); + println!("Agent: {}", async_agent.prompt(prompt).await.unwrap()); + } +} diff --git a/rig-core/rig-core-derive/examples/rig_tool/full.rs b/rig-core/rig-core-derive/examples/rig_tool/full.rs new file mode 100644 index 0000000..d012791 --- /dev/null +++ b/rig-core/rig-core-derive/examples/rig_tool/full.rs @@ -0,0 +1,58 @@ +use rig::completion::Prompt; +use rig::providers; +use rig::tool::Tool; +use rig_derive::rig_tool; +use tracing_subscriber; + +// Example with full attributes including parameter descriptions +#[rig_tool( + description = "A tool that performs string operations", + params( + text = "The input text to process", + operation = "The operation to perform (uppercase, lowercase, reverse)", + ) +)] +fn string_processor(text: String, operation: String) -> Result { + let result = match operation.as_str() { + "uppercase" => text.to_uppercase(), + 
"lowercase" => text.to_lowercase(), + "reverse" => text.chars().rev().collect(), + _ => { + return Err(rig::tool::ToolError::ToolCallError( + format!("Unknown operation: {}", operation).into(), + )) + } + }; + + Ok(result) +} + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt().pretty().init(); + + let string_agent = providers::openai::Client::from_env() + .agent(providers::openai::GPT_4O) + .preamble("You are an agent with tools access, always use the tools") + .max_tokens(1024) + .tool(StringProcessor) + .build(); + + println!("Tool definition:"); + println!( + "STRINGPROCESSOR: {}", + serde_json::to_string_pretty(&StringProcessor.definition(String::default()).await).unwrap() + ); + + for prompt in [ + "What tools do you have?", + "Convert 'hello world' to uppercase", + "Convert 'HELLO WORLD' to lowercase", + "Reverse the string 'hello world'", + "Convert 'hello world' to uppercase and repeat it 3 times", + "Perform an invalid operation on 'hello world'", + ] { + println!("User: {}", prompt); + println!("Agent: {}", string_agent.prompt(prompt).await.unwrap()); + } +} diff --git a/rig-core/rig-core-derive/examples/rig_tool/simple.rs b/rig-core/rig-core-derive/examples/rig_tool/simple.rs new file mode 100644 index 0000000..50309f8 --- /dev/null +++ b/rig-core/rig-core-derive/examples/rig_tool/simple.rs @@ -0,0 +1,71 @@ +use rig::completion::Prompt; +use rig::providers; +use rig_derive::rig_tool; +use tracing_subscriber; + +// Simple example with no attributes +#[rig_tool] +fn add(a: i32, b: i32) -> Result { + Ok(a + b) +} + +#[rig_tool] +fn subtract(a: i32, b: i32) -> Result { + Ok(a - b) +} + +#[rig_tool] +fn multiply(a: i32, b: i32) -> Result { + Ok(a * b) +} + +#[rig_tool] +fn divide(a: i32, b: i32) -> Result { + if b == 0 { + Err(rig::tool::ToolError::ToolCallError( + "Division by zero".into(), + )) + } else { + Ok(a / b) + } +} + +#[rig_tool] +fn answer_secret_question() -> Result<(bool, bool, bool, bool, bool), rig::tool::ToolError> { + Ok((false, 
false, true, false, false)) +} + +#[rig_tool] +fn how_many_rs(s: String) -> Result { + Ok(s.chars() + .filter(|c| *c == 'r' || *c == 'R') + .collect::>() + .len()) +} + +#[rig_tool] +fn sum_numbers(numbers: Vec) -> Result { + Ok(numbers.iter().sum()) +} + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt().pretty().init(); + + let calculator_agent = providers::openai::Client::from_env() + .agent(providers::openai::GPT_4O) + .preamble("You are an agent with tools access, always use the tools") + .max_tokens(1024) + .tool(Add) + .build(); + + for prompt in [ + "What tools do you have?", + "Calculate 5 + 3", + "What is 10 + 20?", + "Add 100 and 200", + ] { + println!("User: {}", prompt); + println!("Agent: {}", calculator_agent.prompt(prompt).await.unwrap()); + } +} diff --git a/rig-core/rig-core-derive/examples/rig_tool/with_description.rs b/rig-core/rig-core-derive/examples/rig_tool/with_description.rs new file mode 100644 index 0000000..9225bd3 --- /dev/null +++ b/rig-core/rig-core-derive/examples/rig_tool/with_description.rs @@ -0,0 +1,57 @@ +use rig::completion::Prompt; +use rig::providers; +use rig::tool::Tool; +use rig_derive::rig_tool; +use tracing_subscriber; + +// Example with description attribute +#[rig_tool(description = "Perform basic arithmetic operations")] +fn calculator(x: i32, y: i32, operation: String) -> Result { + match operation.as_str() { + "add" => Ok(x + y), + "subtract" => Ok(x - y), + "multiply" => Ok(x * y), + "divide" => { + if y == 0 { + Err(rig::tool::ToolError::ToolCallError( + "Division by zero".into(), + )) + } else { + Ok(x / y) + } + } + _ => Err(rig::tool::ToolError::ToolCallError( + format!("Unknown operation: {}", operation).into(), + )), + } +} + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt().pretty().init(); + + let calculator_agent = providers::openai::Client::from_env() + .agent(providers::openai::GPT_4O) + .preamble("You are an agent with tools access, always use the tools") + .max_tokens(1024) + 
.tool(Calculator) + .build(); + + println!("Tool definition:"); + println!( + "CALCULATOR: {}", + serde_json::to_string_pretty(&CALCULATOR.definition(String::default()).await).unwrap() + ); + + for prompt in [ + "What tools do you have?", + "Calculate 5 + 3", + "What is 10 - 4?", + "Multiply 6 and 7", + "Divide 20 by 5", + "What is 10 / 0?", + ] { + println!("User: {}", prompt); + println!("Agent: {}", calculator_agent.prompt(prompt).await.unwrap()); + } +} diff --git a/rig-core/rig-core-derive/src/lib.rs b/rig-core/rig-core-derive/src/lib.rs index 4ce20cf..7d800bc 100644 --- a/rig-core/rig-core-derive/src/lib.rs +++ b/rig-core/rig-core-derive/src/lib.rs @@ -1,6 +1,15 @@ extern crate proc_macro; + +use convert_case::{Case, Casing}; use proc_macro::TokenStream; -use syn::{parse_macro_input, DeriveInput}; +use quote::{format_ident, quote}; +use std::{collections::HashMap, ops::Deref}; +use syn::{ + parse::{Parse, ParseStream}, + parse_macro_input, + punctuated::Punctuated, + DeriveInput, Expr, ExprLit, Lit, Meta, PathArguments, ReturnType, Token, Type, +}; mod basic; mod custom; @@ -19,3 +28,311 @@ pub fn derive_embedding_trait(item: TokenStream) -> TokenStream { .unwrap_or_else(syn::Error::into_compile_error) .into() } + +struct MacroArgs { + description: Option, + param_descriptions: HashMap, +} + +impl Parse for MacroArgs { + fn parse(input: ParseStream) -> syn::Result { + let mut description = None; + let mut param_descriptions = HashMap::new(); + + // If the input is empty, return default values + if input.is_empty() { + return Ok(MacroArgs { + description, + param_descriptions, + }); + } + + let meta_list: Punctuated = Punctuated::parse_terminated(input)?; + + for meta in meta_list { + match meta { + Meta::NameValue(nv) => { + let ident = nv.path.get_ident().unwrap().to_string(); + if let Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. 
+ }) = nv.value + { + if ident.as_str() == "description" { + description = Some(lit_str.value()); + } + } + } + Meta::List(list) if list.path.is_ident("params") => { + let nested: Punctuated = + list.parse_args_with(Punctuated::parse_terminated)?; + + for meta in nested { + if let Meta::NameValue(nv) = meta { + if let Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. + }) = nv.value + { + let param_name = nv.path.get_ident().unwrap().to_string(); + param_descriptions.insert(param_name, lit_str.value()); + } + } + } + } + _ => {} + } + } + + Ok(MacroArgs { + description, + param_descriptions, + }) + } +} + +fn get_json_type(ty: &Type) -> proc_macro2::TokenStream { + match ty { + Type::Path(type_path) => { + let segment = &type_path.path.segments[0]; + let type_name = segment.ident.to_string(); + + // Handle Vec types + if type_name == "Vec" { + if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { + if let syn::GenericArgument::Type(inner_type) = &args.args[0] { + let inner_json_type = get_json_type(inner_type); + return quote! { + "type": "array", + "items": { #inner_json_type } + }; + } + } + return quote! { "type": "array" }; + } + + // Handle primitive types + match type_name.as_str() { + "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "f32" | "f64" => { + quote! { "type": "number" } + } + "String" | "str" => { + quote! { "type": "string" } + } + "bool" => { + quote! { "type": "boolean" } + } + // Handle other types as objects + _ => { + quote! { "type": "object" } + } + } + } + _ => { + quote! { "type": "object" } + } + } +} + +/// A procedural macro that transforms a function into a `rig::tool::Tool` that can be used with a `rig::agent::Agent`. 
+/// +/// # Examples +/// +/// Basic usage: +/// ```rust +/// use rig_derive::rig_tool; +/// +/// #[rig_tool] +/// fn add(a: i32, b: i32) -> Result { +/// Ok(a + b) +/// } +/// ``` +/// +/// With description: +/// ```rust +/// use rig_derive::rig_tool; +/// +/// #[rig_tool(description = "Perform basic arithmetic operations")] +/// fn calculator(x: i32, y: i32, operation: String) -> Result { +/// match operation.as_str() { +/// "add" => Ok(x + y), +/// "subtract" => Ok(x - y), +/// "multiply" => Ok(x * y), +/// "divide" => Ok(x / y), +/// _ => Err(rig::tool::ToolError::ToolCallError("Unknown operation".into())), +/// } +/// } +/// ``` +/// +/// With parameter descriptions: +/// ```rust +/// use rig_derive::rig_tool; +/// +/// #[rig_tool( +/// description = "A tool that performs string operations", +/// params( +/// text = "The input text to process", +/// operation = "The operation to perform (uppercase, lowercase, reverse)" +/// ) +/// )] +/// fn string_processor(text: String, operation: String) -> Result { +/// match operation.as_str() { +/// "uppercase" => Ok(text.to_uppercase()), +/// "lowercase" => Ok(text.to_lowercase()), +/// "reverse" => Ok(text.chars().rev().collect()), +/// _ => Err(rig::tool::ToolError::ToolCallError("Unknown operation".into())), +/// } +/// } +/// ``` +#[proc_macro_attribute] +pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream { + let args = parse_macro_input!(args as MacroArgs); + let input_fn = parse_macro_input!(input as syn::ItemFn); + + // Extract function details + let fn_name = &input_fn.sig.ident; + let fn_name_str = fn_name.to_string(); + let is_async = input_fn.sig.asyncness.is_some(); + + // Extract return type and get Output and Error types from Result + let return_type = &input_fn.sig.output; + let output_type = match return_type { + ReturnType::Type(_, ty) => { + if let Type::Path(type_path) = ty.deref() { + if let Some(last_segment) = type_path.path.segments.last() { + if last_segment.ident == "Result" { 
+ if let PathArguments::AngleBracketed(args) = &last_segment.arguments { + if args.args.len() == 2 { + let output = args.args.first().unwrap(); + let error = args.args.last().unwrap(); + + // Convert the error type to a string for comparison + let error_str = quote!(#error).to_string().replace(" ", ""); + if !error_str.contains("rig::tool::ToolError") { + panic!("Expected rig::tool::ToolError as second type parameter but found {}", error_str); + } + + quote!(#output) + } else { + panic!("Expected Result with two type parameters"); + } + } else { + panic!("Expected angle bracketed type parameters for Result"); + } + } else { + panic!("Return type must be a Result"); + } + } else { + panic!("Invalid return type"); + } + } else { + panic!("Invalid return type"); + } + } + _ => panic!("Function must have a return type"), + }; + + // Generate PascalCase struct name from the function name + let struct_name = format_ident!("{}", { fn_name_str.to_case(Case::Pascal) }); + + // Use provided description or generate a default one + let tool_description = match args.description { + Some(desc) => quote! { #desc.to_string() }, + None => quote! 
{ format!("Function to {}", Self::NAME) }, + }; + + // Extract parameter names, types, and descriptions + let mut param_names = Vec::new(); + let mut param_types = Vec::new(); + let mut param_descriptions = Vec::new(); + let mut json_types = Vec::new(); + + for arg in input_fn.sig.inputs.iter() { + if let syn::FnArg::Typed(pat_type) = arg { + if let syn::Pat::Ident(param_ident) = &*pat_type.pat { + let param_name = &param_ident.ident; + let param_name_str = param_name.to_string(); + let ty = &pat_type.ty; + let default_parameter_description = format!("Parameter {}", param_name_str); + let description = args + .param_descriptions + .get(&param_name_str) + .map(|s| s.to_owned()) + .unwrap_or(default_parameter_description); + + param_names.push(param_name); + param_types.push(ty); + param_descriptions.push(description); + json_types.push(get_json_type(ty)); + } + } + } + + let params_struct_name = format_ident!("{}Parameters", struct_name); + let static_name = format_ident!("{}", fn_name_str.to_uppercase()); + + // Generate the call implementation based on whether the function is async + let call_impl = if is_async { + quote! { + async fn call(&self, args: Self::Args) -> Result { + #fn_name(#(args.#param_names,)*).await.map_err(|e| rig::tool::ToolError::ToolCallError(e.into())) + } + } + } else { + quote! { + async fn call(&self, args: Self::Args) -> Result { + #fn_name(#(args.#param_names,)*).map_err(|e| rig::tool::ToolError::ToolCallError(e.into())) + } + } + }; + + let expanded = quote! 
{ + #[derive(serde::Deserialize)] + struct #params_struct_name { + #(#param_names: #param_types,)* + } + + #input_fn + + #[derive(Default)] + pub(crate) struct #struct_name; + + impl rig::tool::Tool for #struct_name { + const NAME: &'static str = #fn_name_str; + + type Args = #params_struct_name; + type Output = #output_type; + type Error = rig::tool::ToolError; + + fn name(&self) -> String { + #fn_name_str.to_string() + } + + async fn definition(&self, _prompt: String) -> rig::completion::ToolDefinition { + let parameters = serde_json::json!({ + "type": "object", + "properties": { + #( + stringify!(#param_names): { + #json_types, + "description": #param_descriptions + } + ),* + } + }); + + rig::completion::ToolDefinition { + name: #fn_name_str.to_string(), + description: #tool_description.to_string(), + parameters, + } + } + + #call_impl + } + + pub static #static_name: #struct_name = #struct_name; + }; + + TokenStream::from(expanded) +} diff --git a/rig-core/rig-core-derive/tests/calculator.rs b/rig-core/rig-core-derive/tests/calculator.rs new file mode 100644 index 0000000..97e6b8c --- /dev/null +++ b/rig-core/rig-core-derive/tests/calculator.rs @@ -0,0 +1,145 @@ +use rig::tool::Tool; +use rig_derive::rig_tool; + +#[rig_tool( + description = "Perform basic arithmetic operations", + params( + x = "First number in the calculation", + y = "Second number in the calculation", + operation = "The operation to perform (add, subtract, multiply, divide)" + ) +)] +async fn calculator(x: i32, y: i32, operation: String) -> Result { + match operation.as_str() { + "add" => Ok(x + y), + "subtract" => Ok(x - y), + "multiply" => Ok(x * y), + "divide" => { + if y == 0 { + Err(rig::tool::ToolError::ToolCallError( + "Division by zero".into(), + )) + } else { + Ok(x / y) + } + } + _ => Err(rig::tool::ToolError::ToolCallError( + format!("Unknown operation: {}", operation).into(), + )), + } +} + +#[rig_tool( + description = "Perform basic arithmetic operations", + params( + x = "First 
number in the calculation", + y = "Second number in the calculation", + operation = "The operation to perform (add, subtract, multiply, divide)" + ) +)] +fn sync_calculator(x: i32, y: i32, operation: String) -> Result { + match operation.as_str() { + "add" => Ok(x + y), + "subtract" => Ok(x - y), + "multiply" => Ok(x * y), + "divide" => { + if y == 0 { + Err(rig::tool::ToolError::ToolCallError( + "Division by zero".into(), + )) + } else { + Ok(x / y) + } + } + _ => Err(rig::tool::ToolError::ToolCallError( + format!("Unknown operation: {}", operation).into(), + )), + } +} + +#[tokio::test] +async fn test_calculator_tool() { + // Create an instance of our tool + let calculator = Calculator::default(); + + // Test tool information + let definition = calculator.definition(String::default()).await; + println!("{:?}", definition); + assert_eq!(calculator.name(), "calculator"); + assert_eq!( + definition.description, + "Perform basic arithmetic operations" + ); + + // Test valid operations + let test_cases = vec![ + ( + CalculatorParameters { + x: 5, + y: 3, + operation: "add".to_string(), + }, + 8, + ), + ( + CalculatorParameters { + x: 5, + y: 3, + operation: "subtract".to_string(), + }, + 2, + ), + ( + CalculatorParameters { + x: 5, + y: 3, + operation: "multiply".to_string(), + }, + 15, + ), + ( + CalculatorParameters { + x: 6, + y: 2, + operation: "divide".to_string(), + }, + 3, + ), + ]; + + for (input, expected) in test_cases { + let result = calculator.call(input).await.unwrap(); + assert_eq!(result, serde_json::json!(expected)); + } + + // Test division by zero + let div_zero = CalculatorParameters { + x: 5, + y: 0, + operation: "divide".to_string(), + }; + let err = calculator.call(div_zero).await.unwrap_err(); + assert!(matches!(err, rig::tool::ToolError::ToolCallError(_))); + + // Test invalid operation + let invalid_op = CalculatorParameters { + x: 5, + y: 3, + operation: "power".to_string(), + }; + let err = calculator.call(invalid_op).await.unwrap_err(); + 
assert!(matches!(err, rig::tool::ToolError::ToolCallError(_))); + + // Test sync calculator + let sync_calculator = SyncCalculator::default(); + let result = sync_calculator + .call(SyncCalculatorParameters { + x: 5, + y: 3, + operation: "add".to_string(), + }) + .await + .unwrap(); + + assert_eq!(result, serde_json::json!(8)); +} diff --git a/rig-core/src/embeddings/embedding.rs b/rig-core/src/embeddings/embedding.rs index 73cdfc2..4b9a32a 100644 --- a/rig-core/src/embeddings/embedding.rs +++ b/rig-core/src/embeddings/embedding.rs @@ -60,6 +60,35 @@ pub trait EmbeddingModel: Clone + Sync + Send { } } +/// Trait for embedding models that can generate embeddings for images. +pub trait ImageEmbeddingModel: Clone + Sync + Send { + /// The maximum number of images that can be embedded in a single request. + const MAX_DOCUMENTS: usize; + + /// The number of dimensions in the embedding vector. + fn ndims(&self) -> usize; + + /// Embed multiple images in a single request from bytes. + fn embed_images( + &self, + images: impl IntoIterator> + Send, + ) -> impl std::future::Future, EmbeddingError>> + Send; + + /// Embed a single image from bytes. + fn embed_image<'a>( + &'a self, + bytes: &'a [u8], + ) -> impl std::future::Future> + Send { + async move { + Ok(self + .embed_images(vec![bytes.to_owned()]) + .await? + .pop() + .expect("There should be at least one embedding")) + } + } +} + /// Struct that holds a single document and its embedding. 
#[derive(Clone, Default, Deserialize, Serialize, Debug)] pub struct Embedding { diff --git a/rig-core/src/providers/gemini/embedding.rs b/rig-core/src/providers/gemini/embedding.rs index 87cdd32..b14d56c 100644 --- a/rig-core/src/providers/gemini/embedding.rs +++ b/rig-core/src/providers/gemini/embedding.rs @@ -41,26 +41,37 @@ impl embeddings::EmbeddingModel for EmbeddingModel { } } + /// #[cfg_attr(feature = "worker", worker::send)] async fn embed_texts( &self, documents: impl IntoIterator + Send, ) -> Result, EmbeddingError> { - let documents: Vec<_> = documents.into_iter().collect(); - let mut request_body = json!({ - "model": format!("models/{}", self.model), - "content": { - "parts": documents.iter().map(|doc| json!({ "text": doc })).collect::>(), - }, - }); + let documents: Vec = documents.into_iter().collect(); - if let Some(ndims) = self.ndims { - request_body["output_dimensionality"] = json!(ndims); - } + // Google batch embed requests. See docstrings for API ref link. + let requests: Vec<_> = documents + .iter() + .map(|doc| { + json!({ + "model": format!("models/{}", self.model), + "content": json!({ + "parts": [json!({ + "text": doc.to_string() + })] + }), + "output_dimensionality": self.ndims, + }) + }) + .collect(); + + let request_body = json!({ "requests": requests }); + + println!("{}", serde_json::to_string_pretty(&request_body).unwrap()); let response = self .client - .post(&format!("/v1beta/models/{}:embedContent", self.model)) + .post(&format!("/v1beta/models/{}:batchEmbedContents", self.model)) .json(&request_body) .send() .await? 
@@ -70,15 +81,16 @@ impl embeddings::EmbeddingModel for EmbeddingModel { match response { ApiResponse::Ok(response) => { - let chunk_size = self.ndims.unwrap_or_else(|| self.ndims()); - Ok(documents + let docs = documents .into_iter() - .zip(response.embedding.values.chunks(chunk_size)) + .zip(response.embeddings) .map(|(document, embedding)| embeddings::Embedding { document, - vec: embedding.to_vec(), + vec: embedding.values, }) - .collect()) + .collect(); + + Ok(docs) } ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)), } @@ -196,7 +208,7 @@ mod gemini_api_types { #[derive(Debug, Deserialize)] pub struct EmbeddingResponse { - pub embedding: EmbeddingValues, + pub embeddings: Vec, } #[derive(Debug, Deserialize)] diff --git a/rig-core/src/providers/xai/completion.rs b/rig-core/src/providers/xai/completion.rs index 949ee01..6777491 100644 --- a/rig-core/src/providers/xai/completion.rs +++ b/rig-core/src/providers/xai/completion.rs @@ -1,6 +1,6 @@ // ================================================================ //! xAI Completion Integration -//! From [xAI Reference](https://docs.x.ai/api/endpoints#chat-completions) +//! 
From [xAI Reference](https://docs.x.ai/docs/api-reference#chat-completions) // ================================================================ use crate::{ diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index 8c4b8cc..1d9a24d 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -204,6 +204,28 @@ where } } +#[cfg(feature = "mcp")] +impl From<&mcp_core::types::Tool> for ToolDefinition { + fn from(val: &mcp_core::types::Tool) -> Self { + Self { + name: val.name.to_owned(), + description: val.description.to_owned().unwrap_or_default(), + parameters: val.input_schema.to_owned(), + } + } +} + +#[cfg(feature = "mcp")] +impl From for ToolDefinition { + fn from(val: mcp_core::types::Tool) -> Self { + Self { + name: val.name, + description: val.description.unwrap_or_default(), + parameters: val.input_schema, + } + } +} + #[cfg(feature = "mcp")] #[derive(Debug, thiserror::Error)] #[error("MCP tool error: {0}")] diff --git a/rig-lancedb/README.md b/rig-lancedb/README.md index fd7be50..caa56c1 100644 --- a/rig-lancedb/README.md +++ b/rig-lancedb/README.md @@ -17,6 +17,9 @@ ## Rig-Lancedb This companion crate implements a Rig vector store based on Lancedb. +## Pre-requisites +If you are using `rig-lancedb` locally, you must ensure you have `protoc` (the [Protobuf Compiler](https://protobuf.dev/installation/)) installed. + ## Usage Add the companion crate to your `Cargo.toml`, along with the rig-core crate: diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index a4415ba..b7885e0 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -38,8 +38,15 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - let table = db - .create_table( + let table = if db + .table_names() + .execute() + .await? + .contains(&"definitions".to_string()) + { + db.open_table("definitions").execute().await? 
+ } else { + db.create_table( "definitions", RecordBatchIterator::new( vec![as_record_batch(embeddings, model.ndims())], @@ -47,16 +54,19 @@ async fn main() -> Result<(), anyhow::Error> { ), ) .execute() - .await?; + .await? + }; // See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information - table - .create_index( - &["embedding"], - lancedb::index::Index::IvfPq(IvfPqIndexBuilder::default()), - ) - .execute() - .await?; + if table.index_stats("embedding").await?.is_none() { + table + .create_index( + &["embedding"], + lancedb::index::Index::IvfPq(IvfPqIndexBuilder::default()), + ) + .execute() + .await?; + } // Define search_params params that will be used by the vector store to perform the vector search. let search_params = SearchParams::default(); diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 5011238..a5dfd64 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -32,8 +32,15 @@ async fn main() -> Result<(), anyhow::Error> { // Initialize LanceDB locally. let db = lancedb::connect("data/lancedb-store").execute().await?; - let table = db - .create_table( + let table = if db + .table_names() + .execute() + .await? + .contains(&"definitions".to_string()) + { + db.open_table("definitions").execute().await? + } else { + db.create_table( "definitions", RecordBatchIterator::new( vec![as_record_batch(embeddings, model.ndims())], @@ -41,7 +48,8 @@ async fn main() -> Result<(), anyhow::Error> { ), ) .execute() - .await?; + .await? 
+ }; let vector_store = LanceDbVectorIndex::new(table, model, "id", search_params).await?; diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 61267e8..8f70d9f 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -44,8 +44,15 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - let table = db - .create_table( + let table = if db + .table_names() + .execute() + .await? + .contains(&"definitions".to_string()) + { + db.open_table("definitions").execute().await? + } else { + db.create_table( "definitions", RecordBatchIterator::new( vec![as_record_batch(embeddings, model.ndims())], @@ -53,21 +60,24 @@ async fn main() -> Result<(), anyhow::Error> { ), ) .execute() - .await?; + .await? + }; // See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information - table - .create_index( - &["embedding"], - lancedb::index::Index::IvfPq( - IvfPqIndexBuilder::default() - // This overrides the default distance type of L2. - // Needs to be the same distance type as the one used in search params. - .distance_type(DistanceType::Cosine), - ), - ) - .execute() - .await?; + if table.index_stats("embedding").await?.is_none() { + table + .create_index( + &["embedding"], + lancedb::index::Index::IvfPq( + IvfPqIndexBuilder::default() + // This overrides the default distance type of L2. + // Needs to be the same distance type as the one used in search params. + .distance_type(DistanceType::Cosine), + ), + ) + .execute() + .await?; + } // Define search_params params that will be used by the vector store to perform the vector search. let search_params = SearchParams::default().distance_type(DistanceType::Cosine);