mirror of https://github.com/0xplaygrounds/rig
feat: Improve `InMemoryVectorStore` API (#130)
* feat: Improve `InMemoryVectorStore` API * style: clippy+fmt * test: fix test
This commit is contained in:
parent
f4268aba5d
commit
398433a191
|
@ -251,9 +251,9 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||
.build()
|
||||
.await?;
|
||||
|
||||
let index = InMemoryVectorStore::default()
|
||||
.add_documents_with_id(embeddings, |tool| tool.name.clone())?
|
||||
.index(embedding_model);
|
||||
let vector_store =
|
||||
InMemoryVectorStore::from_documents_with_id_f(embeddings, |tool| tool.name.clone());
|
||||
let index = vector_store.index(embedding_model);
|
||||
|
||||
// Create RAG agent with a single context prompt and a dynamic tool source
|
||||
let calculator_rag = openai_client
|
||||
|
|
|
@ -65,9 +65,11 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||
.build()
|
||||
.await?;
|
||||
|
||||
let index = InMemoryVectorStore::default()
|
||||
.add_documents_with_id(embeddings, |definition| definition.id.clone())?
|
||||
.index(embedding_model);
|
||||
// Create vector store with the embeddings
|
||||
let vector_store = InMemoryVectorStore::from_documents(embeddings);
|
||||
|
||||
// Create vector store index
|
||||
let index = vector_store.index(embedding_model);
|
||||
|
||||
let rag_agent = openai_client.agent("gpt-4")
|
||||
.preamble("
|
||||
|
|
|
@ -155,9 +155,12 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||
.build()
|
||||
.await?;
|
||||
|
||||
let index = InMemoryVectorStore::default()
|
||||
.add_documents_with_id(embeddings, |tool| tool.name.clone())?
|
||||
.index(embedding_model);
|
||||
// Create vector store with the embeddings
|
||||
let vector_store =
|
||||
InMemoryVectorStore::from_documents_with_id_f(embeddings, |tool| tool.name.clone());
|
||||
|
||||
// Create vector store index
|
||||
let index = vector_store.index(embedding_model);
|
||||
|
||||
// Create RAG agent with a single context prompt and a dynamic tool source
|
||||
let calculator_rag = openai_client
|
||||
|
|
|
@ -24,9 +24,9 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||
let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
|
||||
let openai_client = Client::new(&openai_api_key);
|
||||
|
||||
let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
|
||||
let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
|
||||
|
||||
let embeddings = EmbeddingsBuilder::new(model.clone())
|
||||
let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
|
||||
.documents(vec![
|
||||
WordDefinition {
|
||||
id: "doc0".to_string(),
|
||||
|
@ -56,9 +56,12 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||
.build()
|
||||
.await?;
|
||||
|
||||
let index = InMemoryVectorStore::default()
|
||||
.add_documents_with_id(embeddings, |definition| definition.id.clone())?
|
||||
.index(model);
|
||||
// Create vector store with the embeddings
|
||||
let vector_store =
|
||||
InMemoryVectorStore::from_documents_with_id_f(embeddings, |doc| doc.id.clone());
|
||||
|
||||
// Create vector store index
|
||||
let index = vector_store.index(embedding_model);
|
||||
|
||||
let results = index
|
||||
.top_n::<WordDefinition>("I need to buy something in a fictional universe. What type of money can I use for this?", 1)
|
||||
|
|
|
@ -57,9 +57,12 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||
.build()
|
||||
.await?;
|
||||
|
||||
let index = InMemoryVectorStore::default()
|
||||
.add_documents_with_id(embeddings, |definition| definition.id.clone())?
|
||||
.index(search_model);
|
||||
// Create vector store with the embeddings
|
||||
let vector_store =
|
||||
InMemoryVectorStore::from_documents_with_id_f(embeddings, |doc| doc.id.clone());
|
||||
|
||||
// Create vector store index
|
||||
let index = vector_store.index(search_model);
|
||||
|
||||
let results = index
|
||||
.top_n::<WordDefinition>(
|
||||
|
|
|
@ -0,0 +1,261 @@
|
|||
use std::future::Future;
|
||||
|
||||
use crate::{completion, vector_store};
|
||||
|
||||
use super::{agent_ops, op};
|
||||
|
||||
// pub struct PipelineBuilder<E> {
|
||||
// _error: std::marker::PhantomData<E>,
|
||||
// }
|
||||
pub struct PipelineBuilder;
|
||||
|
||||
impl PipelineBuilder {
|
||||
/// Chain a function to the current pipeline
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust
|
||||
/// use rig::pipeline::{self, Op};
|
||||
///
|
||||
/// let chain = pipeline::new()
|
||||
/// .map(|(x, y)| x + y)
|
||||
/// .map(|z| format!("Result: {z}!"));
|
||||
///
|
||||
/// let result = chain.call((1, 2)).await;
|
||||
/// assert_eq!(result, "Result: 3!");
|
||||
/// ```
|
||||
pub fn map<F, In, T>(self, f: F) -> op::Map<F, In>
|
||||
where
|
||||
F: Fn(In) -> T + Send + Sync,
|
||||
In: Send + Sync,
|
||||
T: Send + Sync,
|
||||
Self: Sized,
|
||||
{
|
||||
op::Map::new(f)
|
||||
}
|
||||
|
||||
/// Same as `map` but for asynchronous functions
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust
|
||||
/// use rig::pipeline::{self, Op};
|
||||
///
|
||||
/// let chain = pipeline::new()
|
||||
/// .then(|email: String| async move {
|
||||
/// email.split('@').next().unwrap().to_string()
|
||||
/// })
|
||||
/// .then(|username: String| async move {
|
||||
/// format!("Hello, {}!", username)
|
||||
/// });
|
||||
///
|
||||
/// let result = chain.call("bob@gmail.com".to_string()).await;
|
||||
/// assert_eq!(result, "Hello, bob!");
|
||||
/// ```
|
||||
pub fn then<F, In, Fut>(self, f: F) -> op::Then<F, In>
|
||||
where
|
||||
F: Fn(In) -> Fut + Send + Sync,
|
||||
In: Send + Sync,
|
||||
Fut: Future + Send + Sync,
|
||||
Fut::Output: Send + Sync,
|
||||
Self: Sized,
|
||||
{
|
||||
op::Then::new(f)
|
||||
}
|
||||
|
||||
/// Chain an arbitrary operation to the current pipeline.
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust
|
||||
/// use rig::pipeline::{self, Op};
|
||||
///
|
||||
/// struct MyOp;
|
||||
///
|
||||
/// impl Op for MyOp {
|
||||
/// type Input = i32;
|
||||
/// type Output = i32;
|
||||
///
|
||||
/// async fn call(&self, input: Self::Input) -> Self::Output {
|
||||
/// input + 1
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// let chain = pipeline::new()
|
||||
/// .chain(MyOp);
|
||||
///
|
||||
/// let result = chain.call(1).await;
|
||||
/// assert_eq!(result, 2);
|
||||
/// ```
|
||||
pub fn chain<T>(self, op: T) -> T
|
||||
where
|
||||
T: op::Op,
|
||||
Self: Sized,
|
||||
{
|
||||
op
|
||||
}
|
||||
|
||||
/// Chain a lookup operation to the current chain. The lookup operation expects the
|
||||
/// current chain to output a query string. The lookup operation will use the query to
|
||||
/// retrieve the top `n` documents from the index and return them with the query string.
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust
|
||||
/// use rig::chain::{self, Chain};
|
||||
///
|
||||
/// let chain = chain::new()
|
||||
/// .lookup(index, 2)
|
||||
/// .chain(|(query, docs): (_, Vec<String>)| async move {
|
||||
/// format!("User query: {}\n\nTop documents:\n{}", query, docs.join("\n"))
|
||||
/// });
|
||||
///
|
||||
/// let result = chain.call("What is a flurbo?".to_string()).await;
|
||||
/// ```
|
||||
pub fn lookup<I, In, T>(
|
||||
self,
|
||||
index: I,
|
||||
n: usize,
|
||||
) -> agent_ops::Lookup<I, In, T>
|
||||
where
|
||||
I: vector_store::VectorStoreIndex,
|
||||
T: Send + Sync + for<'a> serde::Deserialize<'a>,
|
||||
In: Into<String> + Send + Sync,
|
||||
Self: Sized,
|
||||
{
|
||||
agent_ops::Lookup::new(index, n)
|
||||
}
|
||||
|
||||
/// Chain a prompt operation to the current chain. The prompt operation expects the
|
||||
/// current chain to output a string. The prompt operation will use the string to prompt
|
||||
/// the given agent (or any other type that implements the `Prompt` trait) and return
|
||||
/// the response.
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust
|
||||
/// use rig::chain::{self, Chain};
|
||||
///
|
||||
/// let agent = &openai_client.agent("gpt-4").build();
|
||||
///
|
||||
/// let chain = chain::new()
|
||||
/// .map(|name| format!("Find funny nicknames for the following name: {name}!"))
|
||||
/// .prompt(agent);
|
||||
///
|
||||
/// let result = chain.call("Alice".to_string()).await;
|
||||
/// ```
|
||||
pub fn prompt<P, In>(
|
||||
self,
|
||||
prompt: P,
|
||||
) -> impl op::Op<Input = In, Output = Result<String, completion::PromptError>>
|
||||
where
|
||||
P: completion::Prompt,
|
||||
In: Into<String> + Send + Sync,
|
||||
Self: Sized,
|
||||
{
|
||||
agent_ops::prompt(prompt)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ChainError {
|
||||
#[error("Failed to prompt agent: {0}")]
|
||||
PromptError(#[from] completion::PromptError),
|
||||
|
||||
#[error("Failed to lookup documents: {0}")]
|
||||
LookupError(#[from] vector_store::VectorStoreError),
|
||||
}
|
||||
|
||||
// pub fn new() -> PipelineBuilder<ChainError> {
|
||||
// PipelineBuilder {
|
||||
// _error: std::marker::PhantomData,
|
||||
// }
|
||||
// }
|
||||
|
||||
// pub fn with_error<E>() -> PipelineBuilder<E> {
|
||||
// PipelineBuilder {
|
||||
// _error: std::marker::PhantomData,
|
||||
// }
|
||||
// }
|
||||
|
||||
pub fn new() -> PipelineBuilder {
|
||||
PipelineBuilder
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::pipeline::{op::Op, parallel::{parallel, Parallel}};
|
||||
use agent_ops::tests::{Foo, MockIndex, MockModel};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_prompt_pipeline() {
|
||||
let model = MockModel;
|
||||
|
||||
let chain = super::new()
|
||||
.map(|input| format!("User query: {}", input))
|
||||
.prompt(model);
|
||||
|
||||
let result = chain
|
||||
.call("What is a flurbo?")
|
||||
.await
|
||||
.expect("Failed to run chain");
|
||||
|
||||
assert_eq!(result, "Mock response: User query: What is a flurbo?");
|
||||
}
|
||||
|
||||
// #[tokio::test]
|
||||
// async fn test_lookup_pipeline() {
|
||||
// let index = MockIndex;
|
||||
|
||||
// let chain = super::new()
|
||||
// .lookup::<_, _, Foo>(index, 1)
|
||||
// .map_ok(|docs| format!("Top documents:\n{}", docs[0].foo));
|
||||
|
||||
// let result = chain
|
||||
// .try_call("What is a flurbo?")
|
||||
// .await
|
||||
// .expect("Failed to run chain");
|
||||
|
||||
// assert_eq!(
|
||||
// result,
|
||||
// "User query: What is a flurbo?\n\nTop documents:\nbar"
|
||||
// );
|
||||
// }
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rag_pipeline() {
|
||||
let index = MockIndex;
|
||||
|
||||
let chain = super::new()
|
||||
.chain(parallel!(op::passthrough(), agent_ops::lookup::<_, _, Foo>(index, 1),))
|
||||
.map(|(query, maybe_docs)| match maybe_docs {
|
||||
Ok(docs) => format!("User query: {}\n\nTop documents:\n{}", query, docs[0].foo),
|
||||
Err(err) => format!("Error: {}", err),
|
||||
})
|
||||
.prompt(MockModel);
|
||||
|
||||
let result = chain
|
||||
.call("What is a flurbo?")
|
||||
.await
|
||||
.expect("Failed to run chain");
|
||||
|
||||
assert_eq!(
|
||||
result,
|
||||
"Mock response: User query: What is a flurbo?\n\nTop documents:\nbar"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parallel_chain_compile_check() {
|
||||
let _ = super::new().chain(
|
||||
Parallel::new(
|
||||
op::map(|x: i32| x + 1),
|
||||
Parallel::new(
|
||||
op::map(|x: i32| x * 3),
|
||||
Parallel::new(
|
||||
op::map(|x: i32| format!("{} is the number!", x)),
|
||||
op::map(|x: i32| x == 1),
|
||||
),
|
||||
),
|
||||
)
|
||||
.map(|(r1, (r2, (r3, r4)))| (r1, r2, r3, r4)),
|
||||
);
|
||||
}
|
||||
}
|
|
@ -13,7 +13,7 @@ use crate::{
|
|||
OneOrMany,
|
||||
};
|
||||
|
||||
/// InMemoryVectorStore is a simple in-memory vector store that stores embeddings
|
||||
/// [InMemoryVectorStore] is a simple in-memory vector store that stores embeddings
|
||||
/// in-memory using a HashMap.
|
||||
#[derive(Clone, Default)]
|
||||
pub struct InMemoryVectorStore<D: Serialize> {
|
||||
|
@ -24,8 +24,49 @@ pub struct InMemoryVectorStore<D: Serialize> {
|
|||
}
|
||||
|
||||
impl<D: Serialize + Eq> InMemoryVectorStore<D> {
|
||||
/// Implement vector search on InMemoryVectorStore.
|
||||
/// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for InMemoryVectorStore.
|
||||
/// Create a new [InMemoryVectorStore] from documents and their corresponding embeddings.
|
||||
/// Ids are automatically generated have will have the form `"doc{n}"` where `n`
|
||||
/// is the index of the document.
|
||||
pub fn from_documents(documents: impl IntoIterator<Item = (D, OneOrMany<Embedding>)>) -> Self {
|
||||
let mut store = HashMap::new();
|
||||
documents
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.for_each(|(i, (doc, embeddings))| {
|
||||
store.insert(format!("doc{i}"), (doc, embeddings));
|
||||
});
|
||||
|
||||
Self { embeddings: store }
|
||||
}
|
||||
|
||||
/// Create a new [InMemoryVectorStore] from documents and and their corresponding embeddings with ids.
|
||||
pub fn from_documents_with_ids(
|
||||
documents: impl IntoIterator<Item = (impl ToString, D, OneOrMany<Embedding>)>,
|
||||
) -> Self {
|
||||
let mut store = HashMap::new();
|
||||
documents.into_iter().for_each(|(i, doc, embeddings)| {
|
||||
store.insert(i.to_string(), (doc, embeddings));
|
||||
});
|
||||
|
||||
Self { embeddings: store }
|
||||
}
|
||||
|
||||
/// Create a new [InMemoryVectorStore] from documents and their corresponding embeddings.
|
||||
/// Document ids are generated using the provided function.
|
||||
pub fn from_documents_with_id_f(
|
||||
documents: impl IntoIterator<Item = (D, OneOrMany<Embedding>)>,
|
||||
f: fn(&D) -> String,
|
||||
) -> Self {
|
||||
let mut store = HashMap::new();
|
||||
documents.into_iter().for_each(|(doc, embeddings)| {
|
||||
store.insert(f(&doc), (doc, embeddings));
|
||||
});
|
||||
|
||||
Self { embeddings: store }
|
||||
}
|
||||
|
||||
/// Implement vector search on [InMemoryVectorStore].
|
||||
/// To be used by implementations of [VectorStoreIndex::top_n] and [VectorStoreIndex::top_n_ids] methods.
|
||||
fn vector_search(&self, prompt_embedding: &Embedding, n: usize) -> EmbeddingRanking<D> {
|
||||
// Sort documents by best embedding distance
|
||||
let mut docs = BinaryHeap::new();
|
||||
|
@ -63,32 +104,44 @@ impl<D: Serialize + Eq> InMemoryVectorStore<D> {
|
|||
docs
|
||||
}
|
||||
|
||||
/// Add documents to the store.
|
||||
/// Returns the store with the added documents.
|
||||
/// Add documents and their corresponding embeddings to the store.
|
||||
/// Ids are automatically generated have will have the form `"doc{n}"` where `n`
|
||||
/// is the index of the document.
|
||||
pub fn add_documents(
|
||||
mut self,
|
||||
documents: Vec<(String, D, OneOrMany<Embedding>)>,
|
||||
) -> Result<Self, VectorStoreError> {
|
||||
for (id, doc, embeddings) in documents {
|
||||
self.embeddings.insert(id, (doc, embeddings));
|
||||
}
|
||||
|
||||
Ok(self)
|
||||
&mut self,
|
||||
documents: impl IntoIterator<Item = (D, OneOrMany<Embedding>)>,
|
||||
) {
|
||||
let current_index = self.embeddings.len();
|
||||
documents
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.for_each(|(index, (doc, embeddings))| {
|
||||
self.embeddings
|
||||
.insert(format!("doc{}", index + current_index), (doc, embeddings));
|
||||
});
|
||||
}
|
||||
|
||||
/// Add documents to the store. Define a function that takes as input the reference of the document and returns its id.
|
||||
/// Returns the store with the added documents.
|
||||
pub fn add_documents_with_id(
|
||||
mut self,
|
||||
/// Add documents and their corresponding embeddings to the store with ids.
|
||||
pub fn add_documents_with_ids(
|
||||
&mut self,
|
||||
documents: impl IntoIterator<Item = (impl ToString, D, OneOrMany<Embedding>)>,
|
||||
) {
|
||||
documents.into_iter().for_each(|(id, doc, embeddings)| {
|
||||
self.embeddings.insert(id.to_string(), (doc, embeddings));
|
||||
});
|
||||
}
|
||||
|
||||
/// Add documents and their corresponding embeddings to the store.
|
||||
/// Document ids are generated using the provided function.
|
||||
pub fn add_documents_with_id_f(
|
||||
&mut self,
|
||||
documents: Vec<(D, OneOrMany<Embedding>)>,
|
||||
id_f: fn(&D) -> String,
|
||||
) -> Result<Self, VectorStoreError> {
|
||||
f: fn(&D) -> String,
|
||||
) {
|
||||
for (doc, embeddings) in documents {
|
||||
let id = id_f(&doc);
|
||||
let id = f(&doc);
|
||||
self.embeddings.insert(id, (doc, embeddings));
|
||||
}
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Get the document by its id and deserialize it into the given type.
|
||||
|
@ -215,37 +268,138 @@ mod tests {
|
|||
use super::{InMemoryVectorStore, RankingItem};
|
||||
|
||||
#[test]
|
||||
fn test_single_embedding() {
|
||||
let index = InMemoryVectorStore::default()
|
||||
.add_documents(vec![
|
||||
fn test_auto_ids() {
|
||||
let mut vector_store = InMemoryVectorStore::from_documents(vec![
|
||||
(
|
||||
"glarb-garb",
|
||||
OneOrMany::one(Embedding {
|
||||
document: "glarb-garb".to_string(),
|
||||
vec: vec![0.1, 0.1, 0.5],
|
||||
}),
|
||||
),
|
||||
(
|
||||
"marble-marble",
|
||||
OneOrMany::one(Embedding {
|
||||
document: "marble-marble".to_string(),
|
||||
vec: vec![0.7, -0.3, 0.0],
|
||||
}),
|
||||
),
|
||||
(
|
||||
"flumb-flumb",
|
||||
OneOrMany::one(Embedding {
|
||||
document: "flumb-flumb".to_string(),
|
||||
vec: vec![0.3, 0.7, 0.1],
|
||||
}),
|
||||
),
|
||||
]);
|
||||
|
||||
vector_store.add_documents(vec![
|
||||
(
|
||||
"brotato",
|
||||
OneOrMany::one(Embedding {
|
||||
document: "brotato".to_string(),
|
||||
vec: vec![0.3, 0.7, 0.1],
|
||||
}),
|
||||
),
|
||||
(
|
||||
"ping-pong",
|
||||
OneOrMany::one(Embedding {
|
||||
document: "ping-pong".to_string(),
|
||||
vec: vec![0.7, -0.3, 0.0],
|
||||
}),
|
||||
),
|
||||
]);
|
||||
|
||||
let mut store = vector_store.embeddings.into_iter().collect::<Vec<_>>();
|
||||
store.sort_by_key(|(id, _)| id.clone());
|
||||
|
||||
assert_eq!(
|
||||
store,
|
||||
vec![
|
||||
(
|
||||
"doc0".to_string(),
|
||||
(
|
||||
"glarb-garb",
|
||||
OneOrMany::one(Embedding {
|
||||
document: "glarb-garb".to_string(),
|
||||
vec: vec![0.1, 0.1, 0.5],
|
||||
})
|
||||
)
|
||||
),
|
||||
(
|
||||
"doc1".to_string(),
|
||||
"glarb-garb",
|
||||
OneOrMany::one(Embedding {
|
||||
document: "glarb-garb".to_string(),
|
||||
vec: vec![0.1, 0.1, 0.5],
|
||||
}),
|
||||
(
|
||||
"marble-marble",
|
||||
OneOrMany::one(Embedding {
|
||||
document: "marble-marble".to_string(),
|
||||
vec: vec![0.7, -0.3, 0.0],
|
||||
})
|
||||
)
|
||||
),
|
||||
(
|
||||
"doc2".to_string(),
|
||||
"marble-marble",
|
||||
OneOrMany::one(Embedding {
|
||||
document: "marble-marble".to_string(),
|
||||
vec: vec![0.7, -0.3, 0.0],
|
||||
}),
|
||||
(
|
||||
"flumb-flumb",
|
||||
OneOrMany::one(Embedding {
|
||||
document: "flumb-flumb".to_string(),
|
||||
vec: vec![0.3, 0.7, 0.1],
|
||||
})
|
||||
)
|
||||
),
|
||||
(
|
||||
"doc3".to_string(),
|
||||
"flumb-flumb",
|
||||
OneOrMany::one(Embedding {
|
||||
document: "flumb-flumb".to_string(),
|
||||
vec: vec![0.3, 0.7, 0.1],
|
||||
}),
|
||||
(
|
||||
"brotato",
|
||||
OneOrMany::one(Embedding {
|
||||
document: "brotato".to_string(),
|
||||
vec: vec![0.3, 0.7, 0.1],
|
||||
})
|
||||
)
|
||||
),
|
||||
])
|
||||
.unwrap();
|
||||
(
|
||||
"doc4".to_string(),
|
||||
(
|
||||
"ping-pong",
|
||||
OneOrMany::one(Embedding {
|
||||
document: "ping-pong".to_string(),
|
||||
vec: vec![0.7, -0.3, 0.0],
|
||||
})
|
||||
)
|
||||
)
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
let ranking = index.vector_search(
|
||||
#[test]
|
||||
fn test_single_embedding() {
|
||||
let vector_store = InMemoryVectorStore::from_documents_with_ids(vec![
|
||||
(
|
||||
"doc1",
|
||||
"glarb-garb",
|
||||
OneOrMany::one(Embedding {
|
||||
document: "glarb-garb".to_string(),
|
||||
vec: vec![0.1, 0.1, 0.5],
|
||||
}),
|
||||
),
|
||||
(
|
||||
"doc2",
|
||||
"marble-marble",
|
||||
OneOrMany::one(Embedding {
|
||||
document: "marble-marble".to_string(),
|
||||
vec: vec![0.7, -0.3, 0.0],
|
||||
}),
|
||||
),
|
||||
(
|
||||
"doc3",
|
||||
"flumb-flumb",
|
||||
OneOrMany::one(Embedding {
|
||||
document: "flumb-flumb".to_string(),
|
||||
vec: vec![0.3, 0.7, 0.1],
|
||||
}),
|
||||
),
|
||||
]);
|
||||
|
||||
let ranking = vector_store.vector_search(
|
||||
&Embedding {
|
||||
document: "glarby-glarble".to_string(),
|
||||
vec: vec![0.0, 0.1, 0.6],
|
||||
|
@ -274,57 +428,55 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_multiple_embeddings() {
|
||||
let index = InMemoryVectorStore::default()
|
||||
.add_documents(vec![
|
||||
(
|
||||
"doc1".to_string(),
|
||||
"glarb-garb",
|
||||
OneOrMany::many(vec![
|
||||
Embedding {
|
||||
document: "glarb-garb".to_string(),
|
||||
vec: vec![0.1, 0.1, 0.5],
|
||||
},
|
||||
Embedding {
|
||||
document: "don't-choose-me".to_string(),
|
||||
vec: vec![-0.5, 0.9, 0.1],
|
||||
},
|
||||
])
|
||||
.unwrap(),
|
||||
),
|
||||
(
|
||||
"doc2".to_string(),
|
||||
"marble-marble",
|
||||
OneOrMany::many(vec![
|
||||
Embedding {
|
||||
document: "marble-marble".to_string(),
|
||||
vec: vec![0.7, -0.3, 0.0],
|
||||
},
|
||||
Embedding {
|
||||
document: "sandwich".to_string(),
|
||||
vec: vec![0.5, 0.5, -0.7],
|
||||
},
|
||||
])
|
||||
.unwrap(),
|
||||
),
|
||||
(
|
||||
"doc3".to_string(),
|
||||
"flumb-flumb",
|
||||
OneOrMany::many(vec![
|
||||
Embedding {
|
||||
document: "flumb-flumb".to_string(),
|
||||
vec: vec![0.3, 0.7, 0.1],
|
||||
},
|
||||
Embedding {
|
||||
document: "banana".to_string(),
|
||||
vec: vec![0.1, -0.5, -0.5],
|
||||
},
|
||||
])
|
||||
.unwrap(),
|
||||
),
|
||||
])
|
||||
.unwrap();
|
||||
let vector_store = InMemoryVectorStore::from_documents_with_ids(vec![
|
||||
(
|
||||
"doc1",
|
||||
"glarb-garb",
|
||||
OneOrMany::many(vec![
|
||||
Embedding {
|
||||
document: "glarb-garb".to_string(),
|
||||
vec: vec![0.1, 0.1, 0.5],
|
||||
},
|
||||
Embedding {
|
||||
document: "don't-choose-me".to_string(),
|
||||
vec: vec![-0.5, 0.9, 0.1],
|
||||
},
|
||||
])
|
||||
.unwrap(),
|
||||
),
|
||||
(
|
||||
"doc2",
|
||||
"marble-marble",
|
||||
OneOrMany::many(vec![
|
||||
Embedding {
|
||||
document: "marble-marble".to_string(),
|
||||
vec: vec![0.7, -0.3, 0.0],
|
||||
},
|
||||
Embedding {
|
||||
document: "sandwich".to_string(),
|
||||
vec: vec![0.5, 0.5, -0.7],
|
||||
},
|
||||
])
|
||||
.unwrap(),
|
||||
),
|
||||
(
|
||||
"doc3",
|
||||
"flumb-flumb",
|
||||
OneOrMany::many(vec![
|
||||
Embedding {
|
||||
document: "flumb-flumb".to_string(),
|
||||
vec: vec![0.3, 0.7, 0.1],
|
||||
},
|
||||
Embedding {
|
||||
document: "banana".to_string(),
|
||||
vec: vec![0.1, -0.5, -0.5],
|
||||
},
|
||||
])
|
||||
.unwrap(),
|
||||
),
|
||||
]);
|
||||
|
||||
let ranking = index.vector_search(
|
||||
let ranking = vector_store.vector_search(
|
||||
&Embedding {
|
||||
document: "glarby-glarble".to_string(),
|
||||
vec: vec![0.0, 0.1, 0.6],
|
||||
|
|
Loading…
Reference in New Issue