diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 300c962..21363b3 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -94,6 +94,7 @@ pub mod loaders; pub mod one_or_many; pub mod pipeline; pub mod providers; +pub mod semantic_routing; pub mod streaming; pub mod tool; pub mod transcription; diff --git a/rig-core/src/semantic_routing.rs b/rig-core/src/semantic_routing.rs new file mode 100644 index 0000000..76c19f5 --- /dev/null +++ b/rig-core/src/semantic_routing.rs @@ -0,0 +1,77 @@ +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +use crate::{embeddings::EmbeddingModel, vector_store::VectorStoreIndex}; + +pub struct SemanticRouter { + store: V, + threshold: f64, +} + +impl SemanticRouter { + pub fn builder() -> SemanticRouterBuilder { + SemanticRouterBuilder::new() + } +} + +impl SemanticRouter +where + V: VectorStoreIndex, +{ + pub async fn select_route(&self, query: &str) -> Option { + let res = self.store.top_n(query, 1).await.ok()?; + let (score, _, SemanticRoute { tag }) = res.first()?; + + if *score < self.threshold { + return None; + } + + Some(tag.to_owned()) + } +} + +#[derive(Serialize, Deserialize)] +pub struct SemanticRoute { + tag: String, +} + +pub trait Router: VectorStoreIndex { + fn retrieve_route() -> impl std::future::Future> + Send; +} + +pub struct SemanticRouterBuilder { + store: Option, + threshold: Option, +} + +impl SemanticRouterBuilder { + pub fn new() -> Self { + Self { + store: None, + threshold: None, + } + } + + pub fn store(mut self, router: V) -> Self { + self.store = Some(router); + + self + } + + pub fn threshold(mut self, threshold: f64) -> Self { + self.threshold = Some(threshold); + + self + } + + pub fn build(self) -> Result, Box> { + let Some(store) = self.store else { + return Err("Vector store not present".into()); + }; + + let threshold = self.threshold.unwrap_or(0.9); + + Ok(SemanticRouter { store, threshold }) + } +}