mirror of https://github.com/0xplaygrounds/rig
Merge branch 'main' of https://github.com/0xPlaygrounds/rig into fix/multiple-tool-calling
This commit is contained in:
commit
368f9df033
|
@ -8699,10 +8699,16 @@ dependencies = [
|
||||||
name = "rig-derive"
|
name = "rig-derive"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"convert_case 0.6.0",
|
||||||
"indoc",
|
"indoc",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
|
"rig-core 0.11.0",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
"syn 2.0.100",
|
"syn 2.0.100",
|
||||||
|
"tokio",
|
||||||
|
"tracing-subscriber",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
|
@ -7,10 +7,35 @@ description = "Internal crate that implements Rig derive macros."
|
||||||
repository = "https://github.com/0xPlaygrounds/rig"
|
repository = "https://github.com/0xPlaygrounds/rig"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
convert_case = { version = "0.6.0" }
|
||||||
indoc = "2.0.5"
|
indoc = "2.0.5"
|
||||||
proc-macro2 = { version = "1.0.87", features = ["proc-macro"] }
|
proc-macro2 = { version = "1.0.87", features = ["proc-macro"] }
|
||||||
quote = "1.0.37"
|
quote = "1.0.37"
|
||||||
|
serde_json = "1.0.108"
|
||||||
syn = { version = "2.0.79", features = ["full"]}
|
syn = { version = "2.0.79", features = ["full"]}
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
proc-macro = true
|
proc-macro = true
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
rig-core = { path = "../../rig-core" }
|
||||||
|
serde = "1.0"
|
||||||
|
serde_json = "1.0.108"
|
||||||
|
tokio = { version = "1.44.0", features = ["full"] }
|
||||||
|
tracing-subscriber = "0.3.0"
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "simple"
|
||||||
|
path = "examples/rig_tool/simple.rs"
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "with_description"
|
||||||
|
path = "examples/rig_tool/with_description.rs"
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "full"
|
||||||
|
path = "examples/rig_tool/full.rs"
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "async_tool"
|
||||||
|
path = "examples/rig_tool/async_tool.rs"
|
||||||
|
|
|
@ -0,0 +1,53 @@
|
||||||
|
use rig::completion::Prompt;
|
||||||
|
use rig::providers;
|
||||||
|
use rig::tool::Tool;
|
||||||
|
use rig_derive::rig_tool;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tracing_subscriber;
|
||||||
|
|
||||||
|
// Example demonstrating async tool usage
|
||||||
|
#[rig_tool(
|
||||||
|
description = "A tool that simulates an async operation",
|
||||||
|
params(
|
||||||
|
input = "Input value to process",
|
||||||
|
delay_ms = "Delay in milliseconds before returning result"
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
async fn async_operation(input: String, delay_ms: u64) -> Result<String, rig::tool::ToolError> {
|
||||||
|
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
|
||||||
|
|
||||||
|
Ok(format!(
|
||||||
|
"Processed after {}ms: {}",
|
||||||
|
delay_ms,
|
||||||
|
input.to_uppercase()
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
tracing_subscriber::fmt().pretty().init();
|
||||||
|
|
||||||
|
let async_agent = providers::openai::Client::from_env()
|
||||||
|
.agent(providers::openai::GPT_4O)
|
||||||
|
.preamble("You are an agent with tools access, always use the tools")
|
||||||
|
.max_tokens(1024)
|
||||||
|
.tool(AsyncOperation)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
println!("Tool definition:");
|
||||||
|
println!(
|
||||||
|
"ASYNCOPERATION: {}",
|
||||||
|
serde_json::to_string_pretty(&AsyncOperation.definition(String::default()).await).unwrap()
|
||||||
|
);
|
||||||
|
|
||||||
|
for prompt in [
|
||||||
|
"What tools do you have?",
|
||||||
|
"Process the text 'hello world' with a delay of 1000ms",
|
||||||
|
"Process the text 'async operation' with a delay of 500ms",
|
||||||
|
"Process the text 'concurrent calls' with a delay of 200ms",
|
||||||
|
"Process the text 'error handling' with a delay of 'not a number'",
|
||||||
|
] {
|
||||||
|
println!("User: {}", prompt);
|
||||||
|
println!("Agent: {}", async_agent.prompt(prompt).await.unwrap());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,58 @@
|
||||||
|
use rig::completion::Prompt;
|
||||||
|
use rig::providers;
|
||||||
|
use rig::tool::Tool;
|
||||||
|
use rig_derive::rig_tool;
|
||||||
|
use tracing_subscriber;
|
||||||
|
|
||||||
|
// Example with full attributes including parameter descriptions
|
||||||
|
#[rig_tool(
|
||||||
|
description = "A tool that performs string operations",
|
||||||
|
params(
|
||||||
|
text = "The input text to process",
|
||||||
|
operation = "The operation to perform (uppercase, lowercase, reverse)",
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
fn string_processor(text: String, operation: String) -> Result<String, rig::tool::ToolError> {
|
||||||
|
let result = match operation.as_str() {
|
||||||
|
"uppercase" => text.to_uppercase(),
|
||||||
|
"lowercase" => text.to_lowercase(),
|
||||||
|
"reverse" => text.chars().rev().collect(),
|
||||||
|
_ => {
|
||||||
|
return Err(rig::tool::ToolError::ToolCallError(
|
||||||
|
format!("Unknown operation: {}", operation).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
tracing_subscriber::fmt().pretty().init();
|
||||||
|
|
||||||
|
let string_agent = providers::openai::Client::from_env()
|
||||||
|
.agent(providers::openai::GPT_4O)
|
||||||
|
.preamble("You are an agent with tools access, always use the tools")
|
||||||
|
.max_tokens(1024)
|
||||||
|
.tool(StringProcessor)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
println!("Tool definition:");
|
||||||
|
println!(
|
||||||
|
"STRINGPROCESSOR: {}",
|
||||||
|
serde_json::to_string_pretty(&StringProcessor.definition(String::default()).await).unwrap()
|
||||||
|
);
|
||||||
|
|
||||||
|
for prompt in [
|
||||||
|
"What tools do you have?",
|
||||||
|
"Convert 'hello world' to uppercase",
|
||||||
|
"Convert 'HELLO WORLD' to lowercase",
|
||||||
|
"Reverse the string 'hello world'",
|
||||||
|
"Convert 'hello world' to uppercase and repeat it 3 times",
|
||||||
|
"Perform an invalid operation on 'hello world'",
|
||||||
|
] {
|
||||||
|
println!("User: {}", prompt);
|
||||||
|
println!("Agent: {}", string_agent.prompt(prompt).await.unwrap());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,71 @@
|
||||||
|
use rig::completion::Prompt;
|
||||||
|
use rig::providers;
|
||||||
|
use rig_derive::rig_tool;
|
||||||
|
use tracing_subscriber;
|
||||||
|
|
||||||
|
// Simple example with no attributes
|
||||||
|
#[rig_tool]
|
||||||
|
fn add(a: i32, b: i32) -> Result<i32, rig::tool::ToolError> {
|
||||||
|
Ok(a + b)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rig_tool]
|
||||||
|
fn subtract(a: i32, b: i32) -> Result<i32, rig::tool::ToolError> {
|
||||||
|
Ok(a - b)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rig_tool]
|
||||||
|
fn multiply(a: i32, b: i32) -> Result<i32, rig::tool::ToolError> {
|
||||||
|
Ok(a * b)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rig_tool]
|
||||||
|
fn divide(a: i32, b: i32) -> Result<i32, rig::tool::ToolError> {
|
||||||
|
if b == 0 {
|
||||||
|
Err(rig::tool::ToolError::ToolCallError(
|
||||||
|
"Division by zero".into(),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
Ok(a / b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rig_tool]
|
||||||
|
fn answer_secret_question() -> Result<(bool, bool, bool, bool, bool), rig::tool::ToolError> {
|
||||||
|
Ok((false, false, true, false, false))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rig_tool]
|
||||||
|
fn how_many_rs(s: String) -> Result<usize, rig::tool::ToolError> {
|
||||||
|
Ok(s.chars()
|
||||||
|
.filter(|c| *c == 'r' || *c == 'R')
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rig_tool]
|
||||||
|
fn sum_numbers(numbers: Vec<i64>) -> Result<i64, rig::tool::ToolError> {
|
||||||
|
Ok(numbers.iter().sum())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
tracing_subscriber::fmt().pretty().init();
|
||||||
|
|
||||||
|
let calculator_agent = providers::openai::Client::from_env()
|
||||||
|
.agent(providers::openai::GPT_4O)
|
||||||
|
.preamble("You are an agent with tools access, always use the tools")
|
||||||
|
.max_tokens(1024)
|
||||||
|
.tool(Add)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
for prompt in [
|
||||||
|
"What tools do you have?",
|
||||||
|
"Calculate 5 + 3",
|
||||||
|
"What is 10 + 20?",
|
||||||
|
"Add 100 and 200",
|
||||||
|
] {
|
||||||
|
println!("User: {}", prompt);
|
||||||
|
println!("Agent: {}", calculator_agent.prompt(prompt).await.unwrap());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,57 @@
|
||||||
|
use rig::completion::Prompt;
|
||||||
|
use rig::providers;
|
||||||
|
use rig::tool::Tool;
|
||||||
|
use rig_derive::rig_tool;
|
||||||
|
use tracing_subscriber;
|
||||||
|
|
||||||
|
// Example with description attribute
|
||||||
|
#[rig_tool(description = "Perform basic arithmetic operations")]
|
||||||
|
fn calculator(x: i32, y: i32, operation: String) -> Result<i32, rig::tool::ToolError> {
|
||||||
|
match operation.as_str() {
|
||||||
|
"add" => Ok(x + y),
|
||||||
|
"subtract" => Ok(x - y),
|
||||||
|
"multiply" => Ok(x * y),
|
||||||
|
"divide" => {
|
||||||
|
if y == 0 {
|
||||||
|
Err(rig::tool::ToolError::ToolCallError(
|
||||||
|
"Division by zero".into(),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
Ok(x / y)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Err(rig::tool::ToolError::ToolCallError(
|
||||||
|
format!("Unknown operation: {}", operation).into(),
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
tracing_subscriber::fmt().pretty().init();
|
||||||
|
|
||||||
|
let calculator_agent = providers::openai::Client::from_env()
|
||||||
|
.agent(providers::openai::GPT_4O)
|
||||||
|
.preamble("You are an agent with tools access, always use the tools")
|
||||||
|
.max_tokens(1024)
|
||||||
|
.tool(Calculator)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
println!("Tool definition:");
|
||||||
|
println!(
|
||||||
|
"CALCULATOR: {}",
|
||||||
|
serde_json::to_string_pretty(&CALCULATOR.definition(String::default()).await).unwrap()
|
||||||
|
);
|
||||||
|
|
||||||
|
for prompt in [
|
||||||
|
"What tools do you have?",
|
||||||
|
"Calculate 5 + 3",
|
||||||
|
"What is 10 - 4?",
|
||||||
|
"Multiply 6 and 7",
|
||||||
|
"Divide 20 by 5",
|
||||||
|
"What is 10 / 0?",
|
||||||
|
] {
|
||||||
|
println!("User: {}", prompt);
|
||||||
|
println!("Agent: {}", calculator_agent.prompt(prompt).await.unwrap());
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,6 +1,15 @@
|
||||||
extern crate proc_macro;
|
extern crate proc_macro;
|
||||||
|
|
||||||
|
use convert_case::{Case, Casing};
|
||||||
use proc_macro::TokenStream;
|
use proc_macro::TokenStream;
|
||||||
use syn::{parse_macro_input, DeriveInput};
|
use quote::{format_ident, quote};
|
||||||
|
use std::{collections::HashMap, ops::Deref};
|
||||||
|
use syn::{
|
||||||
|
parse::{Parse, ParseStream},
|
||||||
|
parse_macro_input,
|
||||||
|
punctuated::Punctuated,
|
||||||
|
DeriveInput, Expr, ExprLit, Lit, Meta, PathArguments, ReturnType, Token, Type,
|
||||||
|
};
|
||||||
|
|
||||||
mod basic;
|
mod basic;
|
||||||
mod custom;
|
mod custom;
|
||||||
|
@ -19,3 +28,311 @@ pub fn derive_embedding_trait(item: TokenStream) -> TokenStream {
|
||||||
.unwrap_or_else(syn::Error::into_compile_error)
|
.unwrap_or_else(syn::Error::into_compile_error)
|
||||||
.into()
|
.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct MacroArgs {
|
||||||
|
description: Option<String>,
|
||||||
|
param_descriptions: HashMap<String, String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Parse for MacroArgs {
|
||||||
|
fn parse(input: ParseStream) -> syn::Result<Self> {
|
||||||
|
let mut description = None;
|
||||||
|
let mut param_descriptions = HashMap::new();
|
||||||
|
|
||||||
|
// If the input is empty, return default values
|
||||||
|
if input.is_empty() {
|
||||||
|
return Ok(MacroArgs {
|
||||||
|
description,
|
||||||
|
param_descriptions,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input)?;
|
||||||
|
|
||||||
|
for meta in meta_list {
|
||||||
|
match meta {
|
||||||
|
Meta::NameValue(nv) => {
|
||||||
|
let ident = nv.path.get_ident().unwrap().to_string();
|
||||||
|
if let Expr::Lit(ExprLit {
|
||||||
|
lit: Lit::Str(lit_str),
|
||||||
|
..
|
||||||
|
}) = nv.value
|
||||||
|
{
|
||||||
|
if ident.as_str() == "description" {
|
||||||
|
description = Some(lit_str.value());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Meta::List(list) if list.path.is_ident("params") => {
|
||||||
|
let nested: Punctuated<Meta, Token![,]> =
|
||||||
|
list.parse_args_with(Punctuated::parse_terminated)?;
|
||||||
|
|
||||||
|
for meta in nested {
|
||||||
|
if let Meta::NameValue(nv) = meta {
|
||||||
|
if let Expr::Lit(ExprLit {
|
||||||
|
lit: Lit::Str(lit_str),
|
||||||
|
..
|
||||||
|
}) = nv.value
|
||||||
|
{
|
||||||
|
let param_name = nv.path.get_ident().unwrap().to_string();
|
||||||
|
param_descriptions.insert(param_name, lit_str.value());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(MacroArgs {
|
||||||
|
description,
|
||||||
|
param_descriptions,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_json_type(ty: &Type) -> proc_macro2::TokenStream {
|
||||||
|
match ty {
|
||||||
|
Type::Path(type_path) => {
|
||||||
|
let segment = &type_path.path.segments[0];
|
||||||
|
let type_name = segment.ident.to_string();
|
||||||
|
|
||||||
|
// Handle Vec types
|
||||||
|
if type_name == "Vec" {
|
||||||
|
if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
|
||||||
|
if let syn::GenericArgument::Type(inner_type) = &args.args[0] {
|
||||||
|
let inner_json_type = get_json_type(inner_type);
|
||||||
|
return quote! {
|
||||||
|
"type": "array",
|
||||||
|
"items": { #inner_json_type }
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return quote! { "type": "array" };
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle primitive types
|
||||||
|
match type_name.as_str() {
|
||||||
|
"i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "f32" | "f64" => {
|
||||||
|
quote! { "type": "number" }
|
||||||
|
}
|
||||||
|
"String" | "str" => {
|
||||||
|
quote! { "type": "string" }
|
||||||
|
}
|
||||||
|
"bool" => {
|
||||||
|
quote! { "type": "boolean" }
|
||||||
|
}
|
||||||
|
// Handle other types as objects
|
||||||
|
_ => {
|
||||||
|
quote! { "type": "object" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
quote! { "type": "object" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A procedural macro that transforms a function into a `rig::tool::Tool` that can be used with a `rig::agent::Agent`.
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// Basic usage:
|
||||||
|
/// ```rust
|
||||||
|
/// use rig_derive::rig_tool;
|
||||||
|
///
|
||||||
|
/// #[rig_tool]
|
||||||
|
/// fn add(a: i32, b: i32) -> Result<i32, rig::tool::ToolError> {
|
||||||
|
/// Ok(a + b)
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// With description:
|
||||||
|
/// ```rust
|
||||||
|
/// use rig_derive::rig_tool;
|
||||||
|
///
|
||||||
|
/// #[rig_tool(description = "Perform basic arithmetic operations")]
|
||||||
|
/// fn calculator(x: i32, y: i32, operation: String) -> Result<i32, rig::tool::ToolError> {
|
||||||
|
/// match operation.as_str() {
|
||||||
|
/// "add" => Ok(x + y),
|
||||||
|
/// "subtract" => Ok(x - y),
|
||||||
|
/// "multiply" => Ok(x * y),
|
||||||
|
/// "divide" => Ok(x / y),
|
||||||
|
/// _ => Err(rig::tool::ToolError::ToolCallError("Unknown operation".into())),
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// With parameter descriptions:
|
||||||
|
/// ```rust
|
||||||
|
/// use rig_derive::rig_tool;
|
||||||
|
///
|
||||||
|
/// #[rig_tool(
|
||||||
|
/// description = "A tool that performs string operations",
|
||||||
|
/// params(
|
||||||
|
/// text = "The input text to process",
|
||||||
|
/// operation = "The operation to perform (uppercase, lowercase, reverse)"
|
||||||
|
/// )
|
||||||
|
/// )]
|
||||||
|
/// fn string_processor(text: String, operation: String) -> Result<String, rig::tool::ToolError> {
|
||||||
|
/// match operation.as_str() {
|
||||||
|
/// "uppercase" => Ok(text.to_uppercase()),
|
||||||
|
/// "lowercase" => Ok(text.to_lowercase()),
|
||||||
|
/// "reverse" => Ok(text.chars().rev().collect()),
|
||||||
|
/// _ => Err(rig::tool::ToolError::ToolCallError("Unknown operation".into())),
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
#[proc_macro_attribute]
|
||||||
|
pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream {
|
||||||
|
let args = parse_macro_input!(args as MacroArgs);
|
||||||
|
let input_fn = parse_macro_input!(input as syn::ItemFn);
|
||||||
|
|
||||||
|
// Extract function details
|
||||||
|
let fn_name = &input_fn.sig.ident;
|
||||||
|
let fn_name_str = fn_name.to_string();
|
||||||
|
let is_async = input_fn.sig.asyncness.is_some();
|
||||||
|
|
||||||
|
// Extract return type and get Output and Error types from Result<T, E>
|
||||||
|
let return_type = &input_fn.sig.output;
|
||||||
|
let output_type = match return_type {
|
||||||
|
ReturnType::Type(_, ty) => {
|
||||||
|
if let Type::Path(type_path) = ty.deref() {
|
||||||
|
if let Some(last_segment) = type_path.path.segments.last() {
|
||||||
|
if last_segment.ident == "Result" {
|
||||||
|
if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
|
||||||
|
if args.args.len() == 2 {
|
||||||
|
let output = args.args.first().unwrap();
|
||||||
|
let error = args.args.last().unwrap();
|
||||||
|
|
||||||
|
// Convert the error type to a string for comparison
|
||||||
|
let error_str = quote!(#error).to_string().replace(" ", "");
|
||||||
|
if !error_str.contains("rig::tool::ToolError") {
|
||||||
|
panic!("Expected rig::tool::ToolError as second type parameter but found {}", error_str);
|
||||||
|
}
|
||||||
|
|
||||||
|
quote!(#output)
|
||||||
|
} else {
|
||||||
|
panic!("Expected Result with two type parameters");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
panic!("Expected angle bracketed type parameters for Result");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
panic!("Return type must be a Result");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
panic!("Invalid return type");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
panic!("Invalid return type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => panic!("Function must have a return type"),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Generate PascalCase struct name from the function name
|
||||||
|
let struct_name = format_ident!("{}", { fn_name_str.to_case(Case::Pascal) });
|
||||||
|
|
||||||
|
// Use provided description or generate a default one
|
||||||
|
let tool_description = match args.description {
|
||||||
|
Some(desc) => quote! { #desc.to_string() },
|
||||||
|
None => quote! { format!("Function to {}", Self::NAME) },
|
||||||
|
};
|
||||||
|
|
||||||
|
// Extract parameter names, types, and descriptions
|
||||||
|
let mut param_names = Vec::new();
|
||||||
|
let mut param_types = Vec::new();
|
||||||
|
let mut param_descriptions = Vec::new();
|
||||||
|
let mut json_types = Vec::new();
|
||||||
|
|
||||||
|
for arg in input_fn.sig.inputs.iter() {
|
||||||
|
if let syn::FnArg::Typed(pat_type) = arg {
|
||||||
|
if let syn::Pat::Ident(param_ident) = &*pat_type.pat {
|
||||||
|
let param_name = ¶m_ident.ident;
|
||||||
|
let param_name_str = param_name.to_string();
|
||||||
|
let ty = &pat_type.ty;
|
||||||
|
let default_parameter_description = format!("Parameter {}", param_name_str);
|
||||||
|
let description = args
|
||||||
|
.param_descriptions
|
||||||
|
.get(¶m_name_str)
|
||||||
|
.map(|s| s.to_owned())
|
||||||
|
.unwrap_or(default_parameter_description);
|
||||||
|
|
||||||
|
param_names.push(param_name);
|
||||||
|
param_types.push(ty);
|
||||||
|
param_descriptions.push(description);
|
||||||
|
json_types.push(get_json_type(ty));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let params_struct_name = format_ident!("{}Parameters", struct_name);
|
||||||
|
let static_name = format_ident!("{}", fn_name_str.to_uppercase());
|
||||||
|
|
||||||
|
// Generate the call implementation based on whether the function is async
|
||||||
|
let call_impl = if is_async {
|
||||||
|
quote! {
|
||||||
|
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||||
|
#fn_name(#(args.#param_names,)*).await.map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
quote! {
|
||||||
|
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||||
|
#fn_name(#(args.#param_names,)*).map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let expanded = quote! {
|
||||||
|
#[derive(serde::Deserialize)]
|
||||||
|
struct #params_struct_name {
|
||||||
|
#(#param_names: #param_types,)*
|
||||||
|
}
|
||||||
|
|
||||||
|
#input_fn
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub(crate) struct #struct_name;
|
||||||
|
|
||||||
|
impl rig::tool::Tool for #struct_name {
|
||||||
|
const NAME: &'static str = #fn_name_str;
|
||||||
|
|
||||||
|
type Args = #params_struct_name;
|
||||||
|
type Output = #output_type;
|
||||||
|
type Error = rig::tool::ToolError;
|
||||||
|
|
||||||
|
fn name(&self) -> String {
|
||||||
|
#fn_name_str.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn definition(&self, _prompt: String) -> rig::completion::ToolDefinition {
|
||||||
|
let parameters = serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
#(
|
||||||
|
stringify!(#param_names): {
|
||||||
|
#json_types,
|
||||||
|
"description": #param_descriptions
|
||||||
|
}
|
||||||
|
),*
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
rig::completion::ToolDefinition {
|
||||||
|
name: #fn_name_str.to_string(),
|
||||||
|
description: #tool_description.to_string(),
|
||||||
|
parameters,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#call_impl
|
||||||
|
}
|
||||||
|
|
||||||
|
pub static #static_name: #struct_name = #struct_name;
|
||||||
|
};
|
||||||
|
|
||||||
|
TokenStream::from(expanded)
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,145 @@
|
||||||
|
use rig::tool::Tool;
|
||||||
|
use rig_derive::rig_tool;
|
||||||
|
|
||||||
|
#[rig_tool(
|
||||||
|
description = "Perform basic arithmetic operations",
|
||||||
|
params(
|
||||||
|
x = "First number in the calculation",
|
||||||
|
y = "Second number in the calculation",
|
||||||
|
operation = "The operation to perform (add, subtract, multiply, divide)"
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
async fn calculator(x: i32, y: i32, operation: String) -> Result<i32, rig::tool::ToolError> {
|
||||||
|
match operation.as_str() {
|
||||||
|
"add" => Ok(x + y),
|
||||||
|
"subtract" => Ok(x - y),
|
||||||
|
"multiply" => Ok(x * y),
|
||||||
|
"divide" => {
|
||||||
|
if y == 0 {
|
||||||
|
Err(rig::tool::ToolError::ToolCallError(
|
||||||
|
"Division by zero".into(),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
Ok(x / y)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Err(rig::tool::ToolError::ToolCallError(
|
||||||
|
format!("Unknown operation: {}", operation).into(),
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rig_tool(
|
||||||
|
description = "Perform basic arithmetic operations",
|
||||||
|
params(
|
||||||
|
x = "First number in the calculation",
|
||||||
|
y = "Second number in the calculation",
|
||||||
|
operation = "The operation to perform (add, subtract, multiply, divide)"
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
fn sync_calculator(x: i32, y: i32, operation: String) -> Result<i32, rig::tool::ToolError> {
|
||||||
|
match operation.as_str() {
|
||||||
|
"add" => Ok(x + y),
|
||||||
|
"subtract" => Ok(x - y),
|
||||||
|
"multiply" => Ok(x * y),
|
||||||
|
"divide" => {
|
||||||
|
if y == 0 {
|
||||||
|
Err(rig::tool::ToolError::ToolCallError(
|
||||||
|
"Division by zero".into(),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
Ok(x / y)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Err(rig::tool::ToolError::ToolCallError(
|
||||||
|
format!("Unknown operation: {}", operation).into(),
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_calculator_tool() {
|
||||||
|
// Create an instance of our tool
|
||||||
|
let calculator = Calculator::default();
|
||||||
|
|
||||||
|
// Test tool information
|
||||||
|
let definition = calculator.definition(String::default()).await;
|
||||||
|
println!("{:?}", definition);
|
||||||
|
assert_eq!(calculator.name(), "calculator");
|
||||||
|
assert_eq!(
|
||||||
|
definition.description,
|
||||||
|
"Perform basic arithmetic operations"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Test valid operations
|
||||||
|
let test_cases = vec![
|
||||||
|
(
|
||||||
|
CalculatorParameters {
|
||||||
|
x: 5,
|
||||||
|
y: 3,
|
||||||
|
operation: "add".to_string(),
|
||||||
|
},
|
||||||
|
8,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
CalculatorParameters {
|
||||||
|
x: 5,
|
||||||
|
y: 3,
|
||||||
|
operation: "subtract".to_string(),
|
||||||
|
},
|
||||||
|
2,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
CalculatorParameters {
|
||||||
|
x: 5,
|
||||||
|
y: 3,
|
||||||
|
operation: "multiply".to_string(),
|
||||||
|
},
|
||||||
|
15,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
CalculatorParameters {
|
||||||
|
x: 6,
|
||||||
|
y: 2,
|
||||||
|
operation: "divide".to_string(),
|
||||||
|
},
|
||||||
|
3,
|
||||||
|
),
|
||||||
|
];
|
||||||
|
|
||||||
|
for (input, expected) in test_cases {
|
||||||
|
let result = calculator.call(input).await.unwrap();
|
||||||
|
assert_eq!(result, serde_json::json!(expected));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test division by zero
|
||||||
|
let div_zero = CalculatorParameters {
|
||||||
|
x: 5,
|
||||||
|
y: 0,
|
||||||
|
operation: "divide".to_string(),
|
||||||
|
};
|
||||||
|
let err = calculator.call(div_zero).await.unwrap_err();
|
||||||
|
assert!(matches!(err, rig::tool::ToolError::ToolCallError(_)));
|
||||||
|
|
||||||
|
// Test invalid operation
|
||||||
|
let invalid_op = CalculatorParameters {
|
||||||
|
x: 5,
|
||||||
|
y: 3,
|
||||||
|
operation: "power".to_string(),
|
||||||
|
};
|
||||||
|
let err = calculator.call(invalid_op).await.unwrap_err();
|
||||||
|
assert!(matches!(err, rig::tool::ToolError::ToolCallError(_)));
|
||||||
|
|
||||||
|
// Test sync calculator
|
||||||
|
let sync_calculator = SyncCalculator::default();
|
||||||
|
let result = sync_calculator
|
||||||
|
.call(SyncCalculatorParameters {
|
||||||
|
x: 5,
|
||||||
|
y: 3,
|
||||||
|
operation: "add".to_string(),
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(result, serde_json::json!(8));
|
||||||
|
}
|
|
@ -60,6 +60,35 @@ pub trait EmbeddingModel: Clone + Sync + Send {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Trait for embedding models that can generate embeddings for images.
|
||||||
|
pub trait ImageEmbeddingModel: Clone + Sync + Send {
|
||||||
|
/// The maximum number of images that can be embedded in a single request.
|
||||||
|
const MAX_DOCUMENTS: usize;
|
||||||
|
|
||||||
|
/// The number of dimensions in the embedding vector.
|
||||||
|
fn ndims(&self) -> usize;
|
||||||
|
|
||||||
|
/// Embed multiple images in a single request from bytes.
|
||||||
|
fn embed_images(
|
||||||
|
&self,
|
||||||
|
images: impl IntoIterator<Item = Vec<u8>> + Send,
|
||||||
|
) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;
|
||||||
|
|
||||||
|
/// Embed a single image from bytes.
|
||||||
|
fn embed_image<'a>(
|
||||||
|
&'a self,
|
||||||
|
bytes: &'a [u8],
|
||||||
|
) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + Send {
|
||||||
|
async move {
|
||||||
|
Ok(self
|
||||||
|
.embed_images(vec![bytes.to_owned()])
|
||||||
|
.await?
|
||||||
|
.pop()
|
||||||
|
.expect("There should be at least one embedding"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Struct that holds a single document and its embedding.
|
/// Struct that holds a single document and its embedding.
|
||||||
#[derive(Clone, Default, Deserialize, Serialize, Debug)]
|
#[derive(Clone, Default, Deserialize, Serialize, Debug)]
|
||||||
pub struct Embedding {
|
pub struct Embedding {
|
||||||
|
|
|
@ -41,26 +41,37 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <https://ai.google.dev/api/embeddings#batch_embed_contents-SHELL>
|
||||||
#[cfg_attr(feature = "worker", worker::send)]
|
#[cfg_attr(feature = "worker", worker::send)]
|
||||||
async fn embed_texts(
|
async fn embed_texts(
|
||||||
&self,
|
&self,
|
||||||
documents: impl IntoIterator<Item = String> + Send,
|
documents: impl IntoIterator<Item = String> + Send,
|
||||||
) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
|
) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
|
||||||
let documents: Vec<_> = documents.into_iter().collect();
|
let documents: Vec<String> = documents.into_iter().collect();
|
||||||
let mut request_body = json!({
|
|
||||||
"model": format!("models/{}", self.model),
|
|
||||||
"content": {
|
|
||||||
"parts": documents.iter().map(|doc| json!({ "text": doc })).collect::<Vec<_>>(),
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
if let Some(ndims) = self.ndims {
|
// Google batch embed requests. See docstrings for API ref link.
|
||||||
request_body["output_dimensionality"] = json!(ndims);
|
let requests: Vec<_> = documents
|
||||||
}
|
.iter()
|
||||||
|
.map(|doc| {
|
||||||
|
json!({
|
||||||
|
"model": format!("models/{}", self.model),
|
||||||
|
"content": json!({
|
||||||
|
"parts": [json!({
|
||||||
|
"text": doc.to_string()
|
||||||
|
})]
|
||||||
|
}),
|
||||||
|
"output_dimensionality": self.ndims,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let request_body = json!({ "requests": requests });
|
||||||
|
|
||||||
|
println!("{}", serde_json::to_string_pretty(&request_body).unwrap());
|
||||||
|
|
||||||
let response = self
|
let response = self
|
||||||
.client
|
.client
|
||||||
.post(&format!("/v1beta/models/{}:embedContent", self.model))
|
.post(&format!("/v1beta/models/{}:batchEmbedContents", self.model))
|
||||||
.json(&request_body)
|
.json(&request_body)
|
||||||
.send()
|
.send()
|
||||||
.await?
|
.await?
|
||||||
|
@ -70,15 +81,16 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
|
||||||
|
|
||||||
match response {
|
match response {
|
||||||
ApiResponse::Ok(response) => {
|
ApiResponse::Ok(response) => {
|
||||||
let chunk_size = self.ndims.unwrap_or_else(|| self.ndims());
|
let docs = documents
|
||||||
Ok(documents
|
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.zip(response.embedding.values.chunks(chunk_size))
|
.zip(response.embeddings)
|
||||||
.map(|(document, embedding)| embeddings::Embedding {
|
.map(|(document, embedding)| embeddings::Embedding {
|
||||||
document,
|
document,
|
||||||
vec: embedding.to_vec(),
|
vec: embedding.values,
|
||||||
})
|
})
|
||||||
.collect())
|
.collect();
|
||||||
|
|
||||||
|
Ok(docs)
|
||||||
}
|
}
|
||||||
ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
|
ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
|
||||||
}
|
}
|
||||||
|
@ -196,7 +208,7 @@ mod gemini_api_types {
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct EmbeddingResponse {
|
pub struct EmbeddingResponse {
|
||||||
pub embedding: EmbeddingValues,
|
pub embeddings: Vec<EmbeddingValues>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
// ================================================================
|
// ================================================================
|
||||||
//! xAI Completion Integration
|
//! xAI Completion Integration
|
||||||
//! From [xAI Reference](https://docs.x.ai/api/endpoints#chat-completions)
|
//! From [xAI Reference](https://docs.x.ai/docs/api-reference#chat-completions)
|
||||||
// ================================================================
|
// ================================================================
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
|
|
@ -204,6 +204,28 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "mcp")]
|
||||||
|
impl From<&mcp_core::types::Tool> for ToolDefinition {
|
||||||
|
fn from(val: &mcp_core::types::Tool) -> Self {
|
||||||
|
Self {
|
||||||
|
name: val.name.to_owned(),
|
||||||
|
description: val.description.to_owned().unwrap_or_default(),
|
||||||
|
parameters: val.input_schema.to_owned(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "mcp")]
|
||||||
|
impl From<mcp_core::types::Tool> for ToolDefinition {
|
||||||
|
fn from(val: mcp_core::types::Tool) -> Self {
|
||||||
|
Self {
|
||||||
|
name: val.name,
|
||||||
|
description: val.description.unwrap_or_default(),
|
||||||
|
parameters: val.input_schema,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(feature = "mcp")]
|
#[cfg(feature = "mcp")]
|
||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
#[error("MCP tool error: {0}")]
|
#[error("MCP tool error: {0}")]
|
||||||
|
|
|
@ -17,6 +17,9 @@
|
||||||
## Rig-Lancedb
|
## Rig-Lancedb
|
||||||
This companion crate implements a Rig vector store based on Lancedb.
|
This companion crate implements a Rig vector store based on Lancedb.
|
||||||
|
|
||||||
|
## Pre-requisites
|
||||||
|
If you are using `rig-lancedb` locally, you must ensure you have `protoc` (the [Protobuf Compiler](https://protobuf.dev/installation/)) installed.
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
Add the companion crate to your `Cargo.toml`, along with the rig-core crate:
|
Add the companion crate to your `Cargo.toml`, along with the rig-core crate:
|
||||||
|
|
|
@ -38,8 +38,15 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
.build()
|
.build()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let table = db
|
let table = if db
|
||||||
.create_table(
|
.table_names()
|
||||||
|
.execute()
|
||||||
|
.await?
|
||||||
|
.contains(&"definitions".to_string())
|
||||||
|
{
|
||||||
|
db.open_table("definitions").execute().await?
|
||||||
|
} else {
|
||||||
|
db.create_table(
|
||||||
"definitions",
|
"definitions",
|
||||||
RecordBatchIterator::new(
|
RecordBatchIterator::new(
|
||||||
vec![as_record_batch(embeddings, model.ndims())],
|
vec![as_record_batch(embeddings, model.ndims())],
|
||||||
|
@ -47,16 +54,19 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
.execute()
|
.execute()
|
||||||
.await?;
|
.await?
|
||||||
|
};
|
||||||
|
|
||||||
// See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information
|
// See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information
|
||||||
table
|
if table.index_stats("embedding").await?.is_none() {
|
||||||
.create_index(
|
table
|
||||||
&["embedding"],
|
.create_index(
|
||||||
lancedb::index::Index::IvfPq(IvfPqIndexBuilder::default()),
|
&["embedding"],
|
||||||
)
|
lancedb::index::Index::IvfPq(IvfPqIndexBuilder::default()),
|
||||||
.execute()
|
)
|
||||||
.await?;
|
.execute()
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
// Define search_params params that will be used by the vector store to perform the vector search.
|
// Define search_params params that will be used by the vector store to perform the vector search.
|
||||||
let search_params = SearchParams::default();
|
let search_params = SearchParams::default();
|
||||||
|
|
|
@ -32,8 +32,15 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
// Initialize LanceDB locally.
|
// Initialize LanceDB locally.
|
||||||
let db = lancedb::connect("data/lancedb-store").execute().await?;
|
let db = lancedb::connect("data/lancedb-store").execute().await?;
|
||||||
|
|
||||||
let table = db
|
let table = if db
|
||||||
.create_table(
|
.table_names()
|
||||||
|
.execute()
|
||||||
|
.await?
|
||||||
|
.contains(&"definitions".to_string())
|
||||||
|
{
|
||||||
|
db.open_table("definitions").execute().await?
|
||||||
|
} else {
|
||||||
|
db.create_table(
|
||||||
"definitions",
|
"definitions",
|
||||||
RecordBatchIterator::new(
|
RecordBatchIterator::new(
|
||||||
vec![as_record_batch(embeddings, model.ndims())],
|
vec![as_record_batch(embeddings, model.ndims())],
|
||||||
|
@ -41,7 +48,8 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
.execute()
|
.execute()
|
||||||
.await?;
|
.await?
|
||||||
|
};
|
||||||
|
|
||||||
let vector_store = LanceDbVectorIndex::new(table, model, "id", search_params).await?;
|
let vector_store = LanceDbVectorIndex::new(table, model, "id", search_params).await?;
|
||||||
|
|
||||||
|
|
|
@ -44,8 +44,15 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
.build()
|
.build()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let table = db
|
let table = if db
|
||||||
.create_table(
|
.table_names()
|
||||||
|
.execute()
|
||||||
|
.await?
|
||||||
|
.contains(&"definitions".to_string())
|
||||||
|
{
|
||||||
|
db.open_table("definitions").execute().await?
|
||||||
|
} else {
|
||||||
|
db.create_table(
|
||||||
"definitions",
|
"definitions",
|
||||||
RecordBatchIterator::new(
|
RecordBatchIterator::new(
|
||||||
vec![as_record_batch(embeddings, model.ndims())],
|
vec![as_record_batch(embeddings, model.ndims())],
|
||||||
|
@ -53,21 +60,24 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
.execute()
|
.execute()
|
||||||
.await?;
|
.await?
|
||||||
|
};
|
||||||
|
|
||||||
// See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information
|
// See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information
|
||||||
table
|
if table.index_stats("embedding").await?.is_none() {
|
||||||
.create_index(
|
table
|
||||||
&["embedding"],
|
.create_index(
|
||||||
lancedb::index::Index::IvfPq(
|
&["embedding"],
|
||||||
IvfPqIndexBuilder::default()
|
lancedb::index::Index::IvfPq(
|
||||||
// This overrides the default distance type of L2.
|
IvfPqIndexBuilder::default()
|
||||||
// Needs to be the same distance type as the one used in search params.
|
// This overrides the default distance type of L2.
|
||||||
.distance_type(DistanceType::Cosine),
|
// Needs to be the same distance type as the one used in search params.
|
||||||
),
|
.distance_type(DistanceType::Cosine),
|
||||||
)
|
),
|
||||||
.execute()
|
)
|
||||||
.await?;
|
.execute()
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
// Define search_params params that will be used by the vector store to perform the vector search.
|
// Define search_params params that will be used by the vector store to perform the vector search.
|
||||||
let search_params = SearchParams::default().distance_type(DistanceType::Cosine);
|
let search_params = SearchParams::default().distance_type(DistanceType::Cosine);
|
||||||
|
|
Loading…
Reference in New Issue