Merge branch 'main' of https://github.com/0xPlaygrounds/rig into fix/multiple-tool-calling

This commit is contained in:
0xMochan 2025-04-14 16:45:09 -07:00
commit 368f9df033
16 changed files with 873 additions and 47 deletions

6
Cargo.lock generated
View File

@ -8699,10 +8699,16 @@ dependencies = [
name = "rig-derive" name = "rig-derive"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"convert_case 0.6.0",
"indoc", "indoc",
"proc-macro2", "proc-macro2",
"quote", "quote",
"rig-core 0.11.0",
"serde",
"serde_json",
"syn 2.0.100", "syn 2.0.100",
"tokio",
"tracing-subscriber",
] ]
[[package]] [[package]]

View File

@ -7,10 +7,35 @@ description = "Internal crate that implements Rig derive macros."
repository = "https://github.com/0xPlaygrounds/rig" repository = "https://github.com/0xPlaygrounds/rig"
[dependencies] [dependencies]
convert_case = { version = "0.6.0" }
indoc = "2.0.5" indoc = "2.0.5"
proc-macro2 = { version = "1.0.87", features = ["proc-macro"] } proc-macro2 = { version = "1.0.87", features = ["proc-macro"] }
quote = "1.0.37" quote = "1.0.37"
serde_json = "1.0.108"
syn = { version = "2.0.79", features = ["full"]} syn = { version = "2.0.79", features = ["full"]}
[lib] [lib]
proc-macro = true proc-macro = true
[dev-dependencies]
rig-core = { path = "../../rig-core" }
serde = "1.0"
serde_json = "1.0.108"
tokio = { version = "1.44.0", features = ["full"] }
tracing-subscriber = "0.3.0"
[[example]]
name = "simple"
path = "examples/rig_tool/simple.rs"
[[example]]
name = "with_description"
path = "examples/rig_tool/with_description.rs"
[[example]]
name = "full"
path = "examples/rig_tool/full.rs"
[[example]]
name = "async_tool"
path = "examples/rig_tool/async_tool.rs"

View File

@ -0,0 +1,53 @@
use rig::completion::Prompt;
use rig::providers;
use rig::tool::Tool;
use rig_derive::rig_tool;
use std::time::Duration;
use tracing_subscriber;
// Example demonstrating async tool usage
#[rig_tool(
    description = "A tool that simulates an async operation",
    params(
        input = "Input value to process",
        delay_ms = "Delay in milliseconds before returning result"
    )
)]
/// Waits for `delay_ms` milliseconds, then returns `input` upper-cased together
/// with the delay that was applied.
async fn async_operation(input: String, delay_ms: u64) -> Result<String, rig::tool::ToolError> {
    // Non-blocking sleep so the async runtime stays free while we wait.
    let delay = Duration::from_millis(delay_ms);
    tokio::time::sleep(delay).await;

    let shouted = input.to_uppercase();
    Ok(format!("Processed after {}ms: {}", delay_ms, shouted))
}
// Builds an OpenAI agent equipped with the `AsyncOperation` tool, prints the
// generated tool definition and then runs a fixed list of prompts through it.
#[tokio::main]
async fn main() {
    tracing_subscriber::fmt().pretty().init();

    let async_agent = providers::openai::Client::from_env()
        .agent(providers::openai::GPT_4O)
        .preamble("You are an agent with tools access, always use the tools")
        .max_tokens(1024)
        .tool(AsyncOperation)
        .build();

    // Show the JSON definition the macro generated for the tool.
    let definition = AsyncOperation.definition(String::default()).await;
    let definition_json = serde_json::to_string_pretty(&definition).unwrap();
    println!("Tool definition:");
    println!("ASYNCOPERATION: {}", definition_json);

    let prompts = [
        "What tools do you have?",
        "Process the text 'hello world' with a delay of 1000ms",
        "Process the text 'async operation' with a delay of 500ms",
        "Process the text 'concurrent calls' with a delay of 200ms",
        "Process the text 'error handling' with a delay of 'not a number'",
    ];
    for prompt in prompts {
        println!("User: {}", prompt);
        let answer = async_agent.prompt(prompt).await.unwrap();
        println!("Agent: {}", answer);
    }
}

View File

@ -0,0 +1,58 @@
use rig::completion::Prompt;
use rig::providers;
use rig::tool::Tool;
use rig_derive::rig_tool;
use tracing_subscriber;
// Example with full attributes including parameter descriptions
#[rig_tool(
    description = "A tool that performs string operations",
    params(
        text = "The input text to process",
        operation = "The operation to perform (uppercase, lowercase, reverse)",
    )
)]
/// Applies `operation` ("uppercase", "lowercase" or "reverse") to `text`.
/// Any other operation name yields a `ToolCallError`.
fn string_processor(text: String, operation: String) -> Result<String, rig::tool::ToolError> {
    match operation.as_str() {
        "uppercase" => Ok(text.to_uppercase()),
        "lowercase" => Ok(text.to_lowercase()),
        "reverse" => Ok(text.chars().rev().collect()),
        _ => Err(rig::tool::ToolError::ToolCallError(
            format!("Unknown operation: {}", operation).into(),
        )),
    }
}
// Builds an OpenAI agent equipped with the `StringProcessor` tool, prints the
// generated tool definition and then runs a fixed list of prompts through it.
#[tokio::main]
async fn main() {
    tracing_subscriber::fmt().pretty().init();

    let string_agent = providers::openai::Client::from_env()
        .agent(providers::openai::GPT_4O)
        .preamble("You are an agent with tools access, always use the tools")
        .max_tokens(1024)
        .tool(StringProcessor)
        .build();

    // Show the JSON definition the macro generated for the tool.
    let definition = StringProcessor.definition(String::default()).await;
    let definition_json = serde_json::to_string_pretty(&definition).unwrap();
    println!("Tool definition:");
    println!("STRINGPROCESSOR: {}", definition_json);

    let prompts = [
        "What tools do you have?",
        "Convert 'hello world' to uppercase",
        "Convert 'HELLO WORLD' to lowercase",
        "Reverse the string 'hello world'",
        "Convert 'hello world' to uppercase and repeat it 3 times",
        "Perform an invalid operation on 'hello world'",
    ];
    for prompt in prompts {
        println!("User: {}", prompt);
        let answer = string_agent.prompt(prompt).await.unwrap();
        println!("Agent: {}", answer);
    }
}

View File

@ -0,0 +1,71 @@
use rig::completion::Prompt;
use rig::providers;
use rig_derive::rig_tool;
use tracing_subscriber;
// Simple example with no attributes
#[rig_tool]
/// Returns the sum of `a` and `b`.
fn add(a: i32, b: i32) -> Result<i32, rig::tool::ToolError> {
    let sum = a + b;
    Ok(sum)
}
#[rig_tool]
/// Returns `a` minus `b`.
fn subtract(a: i32, b: i32) -> Result<i32, rig::tool::ToolError> {
    let difference = a - b;
    Ok(difference)
}
#[rig_tool]
/// Returns the product of `a` and `b`.
fn multiply(a: i32, b: i32) -> Result<i32, rig::tool::ToolError> {
    let product = a * b;
    Ok(product)
}
#[rig_tool]
/// Returns the integer quotient `a / b`.
///
/// Fails with a `ToolCallError` when `b` is zero, or when the quotient does not
/// fit in an `i32` (`i32::MIN / -1`), instead of panicking on model-supplied input.
fn divide(a: i32, b: i32) -> Result<i32, rig::tool::ToolError> {
    if b == 0 {
        return Err(rig::tool::ToolError::ToolCallError(
            "Division by zero".into(),
        ));
    }

    // `checked_div` also covers `i32::MIN / -1`, which would otherwise panic.
    a.checked_div(b)
        .ok_or_else(|| rig::tool::ToolError::ToolCallError("Division overflow".into()))
}
#[rig_tool]
/// Takes no arguments and always returns the same fixed tuple of five flags.
fn answer_secret_question() -> Result<(bool, bool, bool, bool, bool), rig::tool::ToolError> {
    let answer = (false, false, true, false, false);
    Ok(answer)
}
#[rig_tool]
/// Counts how many characters of `s` are the letter `r`, in either case.
fn how_many_rs(s: String) -> Result<usize, rig::tool::ToolError> {
    // Count directly on the iterator; no intermediate `Vec` is needed.
    let count = s.chars().filter(|&c| c == 'r' || c == 'R').count();
    Ok(count)
}
#[rig_tool]
/// Returns the sum of every value in `numbers` (zero for an empty list).
fn sum_numbers(numbers: Vec<i64>) -> Result<i64, rig::tool::ToolError> {
    let total = numbers.into_iter().sum();
    Ok(total)
}
// Builds an OpenAI agent that only has the `Add` tool registered and runs a
// fixed list of addition prompts through it.
#[tokio::main]
async fn main() {
    tracing_subscriber::fmt().pretty().init();

    let calculator_agent = providers::openai::Client::from_env()
        .agent(providers::openai::GPT_4O)
        .preamble("You are an agent with tools access, always use the tools")
        .max_tokens(1024)
        .tool(Add)
        .build();

    let prompts = [
        "What tools do you have?",
        "Calculate 5 + 3",
        "What is 10 + 20?",
        "Add 100 and 200",
    ];
    for prompt in prompts {
        println!("User: {}", prompt);
        let answer = calculator_agent.prompt(prompt).await.unwrap();
        println!("Agent: {}", answer);
    }
}

View File

@ -0,0 +1,57 @@
use rig::completion::Prompt;
use rig::providers;
use rig::tool::Tool;
use rig_derive::rig_tool;
use tracing_subscriber;
// Example with description attribute
#[rig_tool(description = "Perform basic arithmetic operations")]
/// Applies `operation` ("add", "subtract", "multiply" or "divide") to `x` and `y`.
/// Division by zero and unknown operation names yield a `ToolCallError`.
fn calculator(x: i32, y: i32, operation: String) -> Result<i32, rig::tool::ToolError> {
    match operation.as_str() {
        "add" => Ok(x + y),
        "subtract" => Ok(x - y),
        "multiply" => Ok(x * y),
        // Guarded arm: reject a zero divisor before the plain division arm.
        "divide" if y == 0 => Err(rig::tool::ToolError::ToolCallError(
            "Division by zero".into(),
        )),
        "divide" => Ok(x / y),
        _ => Err(rig::tool::ToolError::ToolCallError(
            format!("Unknown operation: {}", operation).into(),
        )),
    }
}
// Builds an OpenAI agent equipped with the `Calculator` tool, prints the
// generated tool definition and then runs a fixed list of prompts through it.
#[tokio::main]
async fn main() {
    tracing_subscriber::fmt().pretty().init();

    let calculator_agent = providers::openai::Client::from_env()
        .agent(providers::openai::GPT_4O)
        .preamble("You are an agent with tools access, always use the tools")
        .max_tokens(1024)
        .tool(Calculator)
        .build();

    // `CALCULATOR` is the static tool instance generated by `#[rig_tool]`.
    let definition = CALCULATOR.definition(String::default()).await;
    let definition_json = serde_json::to_string_pretty(&definition).unwrap();
    println!("Tool definition:");
    println!("CALCULATOR: {}", definition_json);

    let prompts = [
        "What tools do you have?",
        "Calculate 5 + 3",
        "What is 10 - 4?",
        "Multiply 6 and 7",
        "Divide 20 by 5",
        "What is 10 / 0?",
    ];
    for prompt in prompts {
        println!("User: {}", prompt);
        let answer = calculator_agent.prompt(prompt).await.unwrap();
        println!("Agent: {}", answer);
    }
}

View File

@ -1,6 +1,15 @@
extern crate proc_macro; extern crate proc_macro;
use convert_case::{Case, Casing};
use proc_macro::TokenStream; use proc_macro::TokenStream;
use syn::{parse_macro_input, DeriveInput}; use quote::{format_ident, quote};
use std::{collections::HashMap, ops::Deref};
use syn::{
parse::{Parse, ParseStream},
parse_macro_input,
punctuated::Punctuated,
DeriveInput, Expr, ExprLit, Lit, Meta, PathArguments, ReturnType, Token, Type,
};
mod basic; mod basic;
mod custom; mod custom;
@ -19,3 +28,311 @@ pub fn derive_embedding_trait(item: TokenStream) -> TokenStream {
.unwrap_or_else(syn::Error::into_compile_error) .unwrap_or_else(syn::Error::into_compile_error)
.into() .into()
} }
/// Arguments accepted by the `#[rig_tool(...)]` attribute.
struct MacroArgs {
    /// Value of `description = "..."`, if one was given.
    description: Option<String>,
    /// Entries of `params(name = "...")`, keyed by parameter name.
    param_descriptions: HashMap<String, String>,
}
impl Parse for MacroArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut description = None;
let mut param_descriptions = HashMap::new();
// If the input is empty, return default values
if input.is_empty() {
return Ok(MacroArgs {
description,
param_descriptions,
});
}
let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input)?;
for meta in meta_list {
match meta {
Meta::NameValue(nv) => {
let ident = nv.path.get_ident().unwrap().to_string();
if let Expr::Lit(ExprLit {
lit: Lit::Str(lit_str),
..
}) = nv.value
{
if ident.as_str() == "description" {
description = Some(lit_str.value());
}
}
}
Meta::List(list) if list.path.is_ident("params") => {
let nested: Punctuated<Meta, Token![,]> =
list.parse_args_with(Punctuated::parse_terminated)?;
for meta in nested {
if let Meta::NameValue(nv) = meta {
if let Expr::Lit(ExprLit {
lit: Lit::Str(lit_str),
..
}) = nv.value
{
let param_name = nv.path.get_ident().unwrap().to_string();
param_descriptions.insert(param_name, lit_str.value());
}
}
}
}
_ => {}
}
}
Ok(MacroArgs {
description,
param_descriptions,
})
}
}
fn get_json_type(ty: &Type) -> proc_macro2::TokenStream {
match ty {
Type::Path(type_path) => {
let segment = &type_path.path.segments[0];
let type_name = segment.ident.to_string();
// Handle Vec types
if type_name == "Vec" {
if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
if let syn::GenericArgument::Type(inner_type) = &args.args[0] {
let inner_json_type = get_json_type(inner_type);
return quote! {
"type": "array",
"items": { #inner_json_type }
};
}
}
return quote! { "type": "array" };
}
// Handle primitive types
match type_name.as_str() {
"i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "f32" | "f64" => {
quote! { "type": "number" }
}
"String" | "str" => {
quote! { "type": "string" }
}
"bool" => {
quote! { "type": "boolean" }
}
// Handle other types as objects
_ => {
quote! { "type": "object" }
}
}
}
_ => {
quote! { "type": "object" }
}
}
}
/// A procedural macro that transforms a function into a `rig::tool::Tool` that can be used with a `rig::agent::Agent`.
///
/// # Examples
///
/// Basic usage:
/// ```rust
/// use rig_derive::rig_tool;
///
/// #[rig_tool]
/// fn add(a: i32, b: i32) -> Result<i32, rig::tool::ToolError> {
/// Ok(a + b)
/// }
/// ```
///
/// With description:
/// ```rust
/// use rig_derive::rig_tool;
///
/// #[rig_tool(description = "Perform basic arithmetic operations")]
/// fn calculator(x: i32, y: i32, operation: String) -> Result<i32, rig::tool::ToolError> {
/// match operation.as_str() {
/// "add" => Ok(x + y),
/// "subtract" => Ok(x - y),
/// "multiply" => Ok(x * y),
/// "divide" => Ok(x / y),
/// _ => Err(rig::tool::ToolError::ToolCallError("Unknown operation".into())),
/// }
/// }
/// ```
///
/// With parameter descriptions:
/// ```rust
/// use rig_derive::rig_tool;
///
/// #[rig_tool(
/// description = "A tool that performs string operations",
/// params(
/// text = "The input text to process",
/// operation = "The operation to perform (uppercase, lowercase, reverse)"
/// )
/// )]
/// fn string_processor(text: String, operation: String) -> Result<String, rig::tool::ToolError> {
/// match operation.as_str() {
/// "uppercase" => Ok(text.to_uppercase()),
/// "lowercase" => Ok(text.to_lowercase()),
/// "reverse" => Ok(text.chars().rev().collect()),
/// _ => Err(rig::tool::ToolError::ToolCallError("Unknown operation".into())),
/// }
/// }
/// ```
#[proc_macro_attribute]
pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream {
let args = parse_macro_input!(args as MacroArgs);
let input_fn = parse_macro_input!(input as syn::ItemFn);
// Extract function details
let fn_name = &input_fn.sig.ident;
let fn_name_str = fn_name.to_string();
let is_async = input_fn.sig.asyncness.is_some();
// Extract return type and get Output and Error types from Result<T, E>
let return_type = &input_fn.sig.output;
let output_type = match return_type {
ReturnType::Type(_, ty) => {
if let Type::Path(type_path) = ty.deref() {
if let Some(last_segment) = type_path.path.segments.last() {
if last_segment.ident == "Result" {
if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
if args.args.len() == 2 {
let output = args.args.first().unwrap();
let error = args.args.last().unwrap();
// Convert the error type to a string for comparison
let error_str = quote!(#error).to_string().replace(" ", "");
if !error_str.contains("rig::tool::ToolError") {
panic!("Expected rig::tool::ToolError as second type parameter but found {}", error_str);
}
quote!(#output)
} else {
panic!("Expected Result with two type parameters");
}
} else {
panic!("Expected angle bracketed type parameters for Result");
}
} else {
panic!("Return type must be a Result");
}
} else {
panic!("Invalid return type");
}
} else {
panic!("Invalid return type");
}
}
_ => panic!("Function must have a return type"),
};
// Generate PascalCase struct name from the function name
let struct_name = format_ident!("{}", { fn_name_str.to_case(Case::Pascal) });
// Use provided description or generate a default one
let tool_description = match args.description {
Some(desc) => quote! { #desc.to_string() },
None => quote! { format!("Function to {}", Self::NAME) },
};
// Extract parameter names, types, and descriptions
let mut param_names = Vec::new();
let mut param_types = Vec::new();
let mut param_descriptions = Vec::new();
let mut json_types = Vec::new();
for arg in input_fn.sig.inputs.iter() {
if let syn::FnArg::Typed(pat_type) = arg {
if let syn::Pat::Ident(param_ident) = &*pat_type.pat {
let param_name = &param_ident.ident;
let param_name_str = param_name.to_string();
let ty = &pat_type.ty;
let default_parameter_description = format!("Parameter {}", param_name_str);
let description = args
.param_descriptions
.get(&param_name_str)
.map(|s| s.to_owned())
.unwrap_or(default_parameter_description);
param_names.push(param_name);
param_types.push(ty);
param_descriptions.push(description);
json_types.push(get_json_type(ty));
}
}
}
let params_struct_name = format_ident!("{}Parameters", struct_name);
let static_name = format_ident!("{}", fn_name_str.to_uppercase());
// Generate the call implementation based on whether the function is async
let call_impl = if is_async {
quote! {
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
#fn_name(#(args.#param_names,)*).await.map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))
}
}
} else {
quote! {
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
#fn_name(#(args.#param_names,)*).map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))
}
}
};
let expanded = quote! {
#[derive(serde::Deserialize)]
struct #params_struct_name {
#(#param_names: #param_types,)*
}
#input_fn
#[derive(Default)]
pub(crate) struct #struct_name;
impl rig::tool::Tool for #struct_name {
const NAME: &'static str = #fn_name_str;
type Args = #params_struct_name;
type Output = #output_type;
type Error = rig::tool::ToolError;
fn name(&self) -> String {
#fn_name_str.to_string()
}
async fn definition(&self, _prompt: String) -> rig::completion::ToolDefinition {
let parameters = serde_json::json!({
"type": "object",
"properties": {
#(
stringify!(#param_names): {
#json_types,
"description": #param_descriptions
}
),*
}
});
rig::completion::ToolDefinition {
name: #fn_name_str.to_string(),
description: #tool_description.to_string(),
parameters,
}
}
#call_impl
}
pub static #static_name: #struct_name = #struct_name;
};
TokenStream::from(expanded)
}

View File

@ -0,0 +1,145 @@
use rig::tool::Tool;
use rig_derive::rig_tool;
#[rig_tool(
    description = "Perform basic arithmetic operations",
    params(
        x = "First number in the calculation",
        y = "Second number in the calculation",
        operation = "The operation to perform (add, subtract, multiply, divide)"
    )
)]
/// Async calculator used to exercise the macro on an `async fn`.
/// Division by zero and unknown operation names yield a `ToolCallError`.
async fn calculator(x: i32, y: i32, operation: String) -> Result<i32, rig::tool::ToolError> {
    match operation.as_str() {
        "add" => Ok(x + y),
        "subtract" => Ok(x - y),
        "multiply" => Ok(x * y),
        // Guarded arm: reject a zero divisor before the plain division arm.
        "divide" if y == 0 => Err(rig::tool::ToolError::ToolCallError(
            "Division by zero".into(),
        )),
        "divide" => Ok(x / y),
        _ => Err(rig::tool::ToolError::ToolCallError(
            format!("Unknown operation: {}", operation).into(),
        )),
    }
}
#[rig_tool(
    description = "Perform basic arithmetic operations",
    params(
        x = "First number in the calculation",
        y = "Second number in the calculation",
        operation = "The operation to perform (add, subtract, multiply, divide)"
    )
)]
/// Synchronous calculator used to exercise the macro on a plain `fn`.
/// Division by zero and unknown operation names yield a `ToolCallError`.
fn sync_calculator(x: i32, y: i32, operation: String) -> Result<i32, rig::tool::ToolError> {
    match operation.as_str() {
        "add" => Ok(x + y),
        "subtract" => Ok(x - y),
        "multiply" => Ok(x * y),
        // Guarded arm: reject a zero divisor before the plain division arm.
        "divide" if y == 0 => Err(rig::tool::ToolError::ToolCallError(
            "Division by zero".into(),
        )),
        "divide" => Ok(x / y),
        _ => Err(rig::tool::ToolError::ToolCallError(
            format!("Unknown operation: {}", operation).into(),
        )),
    }
}
// Exercises the generated `Calculator` and `SyncCalculator` tools: metadata,
// the four valid operations, both failure paths, and the sync variant.
#[tokio::test]
async fn test_calculator_tool() {
    // Create an instance of our tool
    let calculator = Calculator::default();

    // Test tool information
    let definition = calculator.definition(String::default()).await;
    println!("{:?}", definition);
    assert_eq!(calculator.name(), "calculator");
    assert_eq!(
        definition.description,
        "Perform basic arithmetic operations"
    );

    // Test valid operations: (x, y, operation, expected result)
    let valid_cases = [
        (5, 3, "add", 8),
        (5, 3, "subtract", 2),
        (5, 3, "multiply", 15),
        (6, 2, "divide", 3),
    ];
    for (x, y, operation, expected) in valid_cases {
        let input = CalculatorParameters {
            x,
            y,
            operation: operation.to_string(),
        };
        let result = calculator.call(input).await.unwrap();
        assert_eq!(result, serde_json::json!(expected));
    }

    // Test division by zero, then an invalid operation: both must fail
    // with a `ToolCallError`.
    for (x, y, operation) in [(5, 0, "divide"), (5, 3, "power")] {
        let input = CalculatorParameters {
            x,
            y,
            operation: operation.to_string(),
        };
        let err = calculator.call(input).await.unwrap_err();
        assert!(matches!(err, rig::tool::ToolError::ToolCallError(_)));
    }

    // Test sync calculator
    let sync_calculator = SyncCalculator::default();
    let sync_input = SyncCalculatorParameters {
        x: 5,
        y: 3,
        operation: "add".to_string(),
    };
    let result = sync_calculator.call(sync_input).await.unwrap();
    assert_eq!(result, serde_json::json!(8));
}

View File

@ -60,6 +60,35 @@ pub trait EmbeddingModel: Clone + Sync + Send {
} }
} }
/// Trait for embedding models that can generate embeddings for images.
pub trait ImageEmbeddingModel: Clone + Sync + Send {
    /// The maximum number of images that can be embedded in a single request.
    const MAX_DOCUMENTS: usize;

    /// The number of dimensions in the embedding vector.
    fn ndims(&self) -> usize;

    /// Embed multiple images in a single request from bytes.
    fn embed_images(
        &self,
        images: impl IntoIterator<Item = Vec<u8>> + Send,
    ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;

    /// Embed a single image from bytes.
    ///
    /// Default implementation: sends a one-element batch through
    /// [`embed_images`](Self::embed_images) and returns the last embedding of
    /// the response.
    ///
    /// # Panics
    ///
    /// Panics if `embed_images` succeeds but returns no embeddings.
    fn embed_image<'a>(
        &'a self,
        bytes: &'a [u8],
    ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + Send {
        async move {
            let batch = vec![bytes.to_owned()];
            let mut embeddings = self.embed_images(batch).await?;
            let embedding = embeddings
                .pop()
                .expect("There should be at least one embedding");
            Ok(embedding)
        }
    }
}
/// Struct that holds a single document and its embedding. /// Struct that holds a single document and its embedding.
#[derive(Clone, Default, Deserialize, Serialize, Debug)] #[derive(Clone, Default, Deserialize, Serialize, Debug)]
pub struct Embedding { pub struct Embedding {

View File

@ -41,26 +41,37 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
} }
} }
/// <https://ai.google.dev/api/embeddings#batch_embed_contents-SHELL>
#[cfg_attr(feature = "worker", worker::send)] #[cfg_attr(feature = "worker", worker::send)]
async fn embed_texts( async fn embed_texts(
&self, &self,
documents: impl IntoIterator<Item = String> + Send, documents: impl IntoIterator<Item = String> + Send,
) -> Result<Vec<embeddings::Embedding>, EmbeddingError> { ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
let documents: Vec<_> = documents.into_iter().collect(); let documents: Vec<String> = documents.into_iter().collect();
let mut request_body = json!({
"model": format!("models/{}", self.model),
"content": {
"parts": documents.iter().map(|doc| json!({ "text": doc })).collect::<Vec<_>>(),
},
});
if let Some(ndims) = self.ndims { // Google batch embed requests. See docstrings for API ref link.
request_body["output_dimensionality"] = json!(ndims); let requests: Vec<_> = documents
} .iter()
.map(|doc| {
json!({
"model": format!("models/{}", self.model),
"content": json!({
"parts": [json!({
"text": doc.to_string()
})]
}),
"output_dimensionality": self.ndims,
})
})
.collect();
let request_body = json!({ "requests": requests });
println!("{}", serde_json::to_string_pretty(&request_body).unwrap());
let response = self let response = self
.client .client
.post(&format!("/v1beta/models/{}:embedContent", self.model)) .post(&format!("/v1beta/models/{}:batchEmbedContents", self.model))
.json(&request_body) .json(&request_body)
.send() .send()
.await? .await?
@ -70,15 +81,16 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
match response { match response {
ApiResponse::Ok(response) => { ApiResponse::Ok(response) => {
let chunk_size = self.ndims.unwrap_or_else(|| self.ndims()); let docs = documents
Ok(documents
.into_iter() .into_iter()
.zip(response.embedding.values.chunks(chunk_size)) .zip(response.embeddings)
.map(|(document, embedding)| embeddings::Embedding { .map(|(document, embedding)| embeddings::Embedding {
document, document,
vec: embedding.to_vec(), vec: embedding.values,
}) })
.collect()) .collect();
Ok(docs)
} }
ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)), ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
} }
@ -196,7 +208,7 @@ mod gemini_api_types {
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct EmbeddingResponse { pub struct EmbeddingResponse {
pub embedding: EmbeddingValues, pub embeddings: Vec<EmbeddingValues>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]

View File

@ -1,6 +1,6 @@
// ================================================================ // ================================================================
//! xAI Completion Integration //! xAI Completion Integration
//! From [xAI Reference](https://docs.x.ai/api/endpoints#chat-completions) //! From [xAI Reference](https://docs.x.ai/docs/api-reference#chat-completions)
// ================================================================ // ================================================================
use crate::{ use crate::{

View File

@ -204,6 +204,28 @@ where
} }
} }
/// Builds a [`ToolDefinition`] from a borrowed MCP tool by cloning its name,
/// description and input schema. A missing description becomes its default value.
#[cfg(feature = "mcp")]
impl From<&mcp_core::types::Tool> for ToolDefinition {
    fn from(val: &mcp_core::types::Tool) -> Self {
        let name = val.name.clone();
        let description = val.description.clone().unwrap_or_default();
        let parameters = val.input_schema.clone();

        Self {
            name,
            description,
            parameters,
        }
    }
}
/// Builds a [`ToolDefinition`] from an owned MCP tool, moving its name,
/// description and input schema. A missing description becomes its default value.
#[cfg(feature = "mcp")]
impl From<mcp_core::types::Tool> for ToolDefinition {
    fn from(val: mcp_core::types::Tool) -> Self {
        let mcp_core::types::Tool {
            name,
            description,
            input_schema,
            ..
        } = val;

        Self {
            name,
            description: description.unwrap_or_default(),
            parameters: input_schema,
        }
    }
}
#[cfg(feature = "mcp")] #[cfg(feature = "mcp")]
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
#[error("MCP tool error: {0}")] #[error("MCP tool error: {0}")]

View File

@ -17,6 +17,9 @@
## Rig-Lancedb ## Rig-Lancedb
This companion crate implements a Rig vector store based on Lancedb. This companion crate implements a Rig vector store based on Lancedb.
## Pre-requisites
If you are using `rig-lancedb` locally, you must ensure you have `protoc` (the [Protobuf Compiler](https://protobuf.dev/installation/)) installed.
## Usage ## Usage
Add the companion crate to your `Cargo.toml`, along with the rig-core crate: Add the companion crate to your `Cargo.toml`, along with the rig-core crate:

View File

@ -38,8 +38,15 @@ async fn main() -> Result<(), anyhow::Error> {
.build() .build()
.await?; .await?;
let table = db let table = if db
.create_table( .table_names()
.execute()
.await?
.contains(&"definitions".to_string())
{
db.open_table("definitions").execute().await?
} else {
db.create_table(
"definitions", "definitions",
RecordBatchIterator::new( RecordBatchIterator::new(
vec![as_record_batch(embeddings, model.ndims())], vec![as_record_batch(embeddings, model.ndims())],
@ -47,16 +54,19 @@ async fn main() -> Result<(), anyhow::Error> {
), ),
) )
.execute() .execute()
.await?; .await?
};
// See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information // See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information
table if table.index_stats("embedding").await?.is_none() {
.create_index( table
&["embedding"], .create_index(
lancedb::index::Index::IvfPq(IvfPqIndexBuilder::default()), &["embedding"],
) lancedb::index::Index::IvfPq(IvfPqIndexBuilder::default()),
.execute() )
.await?; .execute()
.await?;
}
// Define search_params params that will be used by the vector store to perform the vector search. // Define search_params params that will be used by the vector store to perform the vector search.
let search_params = SearchParams::default(); let search_params = SearchParams::default();

View File

@ -32,8 +32,15 @@ async fn main() -> Result<(), anyhow::Error> {
// Initialize LanceDB locally. // Initialize LanceDB locally.
let db = lancedb::connect("data/lancedb-store").execute().await?; let db = lancedb::connect("data/lancedb-store").execute().await?;
let table = db let table = if db
.create_table( .table_names()
.execute()
.await?
.contains(&"definitions".to_string())
{
db.open_table("definitions").execute().await?
} else {
db.create_table(
"definitions", "definitions",
RecordBatchIterator::new( RecordBatchIterator::new(
vec![as_record_batch(embeddings, model.ndims())], vec![as_record_batch(embeddings, model.ndims())],
@ -41,7 +48,8 @@ async fn main() -> Result<(), anyhow::Error> {
), ),
) )
.execute() .execute()
.await?; .await?
};
let vector_store = LanceDbVectorIndex::new(table, model, "id", search_params).await?; let vector_store = LanceDbVectorIndex::new(table, model, "id", search_params).await?;

View File

@ -44,8 +44,15 @@ async fn main() -> Result<(), anyhow::Error> {
.build() .build()
.await?; .await?;
let table = db let table = if db
.create_table( .table_names()
.execute()
.await?
.contains(&"definitions".to_string())
{
db.open_table("definitions").execute().await?
} else {
db.create_table(
"definitions", "definitions",
RecordBatchIterator::new( RecordBatchIterator::new(
vec![as_record_batch(embeddings, model.ndims())], vec![as_record_batch(embeddings, model.ndims())],
@ -53,21 +60,24 @@ async fn main() -> Result<(), anyhow::Error> {
), ),
) )
.execute() .execute()
.await?; .await?
};
// See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information // See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information
table if table.index_stats("embedding").await?.is_none() {
.create_index( table
&["embedding"], .create_index(
lancedb::index::Index::IvfPq( &["embedding"],
IvfPqIndexBuilder::default() lancedb::index::Index::IvfPq(
// This overrides the default distance type of L2. IvfPqIndexBuilder::default()
// Needs to be the same distance type as the one used in search params. // This overrides the default distance type of L2.
.distance_type(DistanceType::Cosine), // Needs to be the same distance type as the one used in search params.
), .distance_type(DistanceType::Cosine),
) ),
.execute() )
.await?; .execute()
.await?;
}
// Define search_params params that will be used by the vector store to perform the vector search. // Define search_params params that will be used by the vector store to perform the vector search.
let search_params = SearchParams::default().distance_type(DistanceType::Cosine); let search_params = SearchParams::default().distance_type(DistanceType::Cosine);