feat: Improve `InMemoryVectorStore` API (#130)

* feat: Improve `InMemoryVectorStore` API

* style: clippy+fmt

* test: fix test
This commit is contained in:
cvauclair 2024-11-29 14:14:35 -05:00 committed by GitHub
parent f4268aba5d
commit 398433a191
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 534 additions and 110 deletions

View File

@ -251,9 +251,9 @@ async fn main() -> Result<(), anyhow::Error> {
.build()
.await?;
let index = InMemoryVectorStore::default()
.add_documents_with_id(embeddings, |tool| tool.name.clone())?
.index(embedding_model);
let vector_store =
InMemoryVectorStore::from_documents_with_id_f(embeddings, |tool| tool.name.clone());
let index = vector_store.index(embedding_model);
// Create RAG agent with a single context prompt and a dynamic tool source
let calculator_rag = openai_client

View File

@ -65,9 +65,11 @@ async fn main() -> Result<(), anyhow::Error> {
.build()
.await?;
let index = InMemoryVectorStore::default()
.add_documents_with_id(embeddings, |definition| definition.id.clone())?
.index(embedding_model);
// Create vector store with the embeddings
let vector_store = InMemoryVectorStore::from_documents(embeddings);
// Create vector store index
let index = vector_store.index(embedding_model);
let rag_agent = openai_client.agent("gpt-4")
.preamble("

View File

@ -155,9 +155,12 @@ async fn main() -> Result<(), anyhow::Error> {
.build()
.await?;
let index = InMemoryVectorStore::default()
.add_documents_with_id(embeddings, |tool| tool.name.clone())?
.index(embedding_model);
// Create vector store with the embeddings
let vector_store =
InMemoryVectorStore::from_documents_with_id_f(embeddings, |tool| tool.name.clone());
// Create vector store index
let index = vector_store.index(embedding_model);
// Create RAG agent with a single context prompt and a dynamic tool source
let calculator_rag = openai_client

View File

@ -24,9 +24,9 @@ async fn main() -> Result<(), anyhow::Error> {
let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
let openai_client = Client::new(&openai_api_key);
let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
let embeddings = EmbeddingsBuilder::new(model.clone())
let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
.documents(vec![
WordDefinition {
id: "doc0".to_string(),
@ -56,9 +56,12 @@ async fn main() -> Result<(), anyhow::Error> {
.build()
.await?;
let index = InMemoryVectorStore::default()
.add_documents_with_id(embeddings, |definition| definition.id.clone())?
.index(model);
// Create vector store with the embeddings
let vector_store =
InMemoryVectorStore::from_documents_with_id_f(embeddings, |doc| doc.id.clone());
// Create vector store index
let index = vector_store.index(embedding_model);
let results = index
.top_n::<WordDefinition>("I need to buy something in a fictional universe. What type of money can I use for this?", 1)

View File

@ -57,9 +57,12 @@ async fn main() -> Result<(), anyhow::Error> {
.build()
.await?;
let index = InMemoryVectorStore::default()
.add_documents_with_id(embeddings, |definition| definition.id.clone())?
.index(search_model);
// Create vector store with the embeddings
let vector_store =
InMemoryVectorStore::from_documents_with_id_f(embeddings, |doc| doc.id.clone());
// Create vector store index
let index = vector_store.index(search_model);
let results = index
.top_n::<WordDefinition>(

View File

@ -0,0 +1,261 @@
use std::future::Future;
use crate::{completion, vector_store};
use super::{agent_ops, op};
// pub struct PipelineBuilder<E> {
// _error: std::marker::PhantomData<E>,
// }
/// Entry point for building a pipeline of chained operations.
/// Unit struct: each builder method consumes `self` and returns the first
/// op of the chain; subsequent ops are chained on the returned value.
pub struct PipelineBuilder;
impl PipelineBuilder {
    /// Chain a function to the current pipeline
    ///
    /// # Example
    /// ```rust
    /// use rig::pipeline::{self, Op};
    ///
    /// let chain = pipeline::new()
    ///     .map(|(x, y)| x + y)
    ///     .map(|z| format!("Result: {z}!"));
    ///
    /// let result = chain.call((1, 2)).await;
    /// assert_eq!(result, "Result: 3!");
    /// ```
    pub fn map<F, In, T>(self, f: F) -> op::Map<F, In>
    where
        F: Fn(In) -> T + Send + Sync,
        In: Send + Sync,
        T: Send + Sync,
        Self: Sized,
    {
        op::Map::new(f)
    }

    /// Same as `map` but for asynchronous functions
    ///
    /// # Example
    /// ```rust
    /// use rig::pipeline::{self, Op};
    ///
    /// let chain = pipeline::new()
    ///     .then(|email: String| async move {
    ///         email.split('@').next().unwrap().to_string()
    ///     })
    ///     .then(|username: String| async move {
    ///         format!("Hello, {}!", username)
    ///     });
    ///
    /// let result = chain.call("bob@gmail.com".to_string()).await;
    /// assert_eq!(result, "Hello, bob!");
    /// ```
    pub fn then<F, In, Fut>(self, f: F) -> op::Then<F, In>
    where
        F: Fn(In) -> Fut + Send + Sync,
        In: Send + Sync,
        Fut: Future + Send + Sync,
        Fut::Output: Send + Sync,
        Self: Sized,
    {
        op::Then::new(f)
    }

    /// Chain an arbitrary operation to the current pipeline.
    ///
    /// # Example
    /// ```rust
    /// use rig::pipeline::{self, Op};
    ///
    /// struct MyOp;
    ///
    /// impl Op for MyOp {
    ///     type Input = i32;
    ///     type Output = i32;
    ///
    ///     async fn call(&self, input: Self::Input) -> Self::Output {
    ///         input + 1
    ///     }
    /// }
    ///
    /// let chain = pipeline::new()
    ///     .chain(MyOp);
    ///
    /// let result = chain.call(1).await;
    /// assert_eq!(result, 2);
    /// ```
    pub fn chain<T>(self, op: T) -> T
    where
        T: op::Op,
        Self: Sized,
    {
        // The builder carries no state, so chaining simply returns the op itself.
        op
    }

    /// Chain a lookup operation to the current chain. The lookup operation expects the
    /// current chain to output a query string. The lookup operation will use the query to
    /// retrieve the top `n` documents from the index and return them with the query string.
    ///
    /// # Example
    /// ```rust
    /// use rig::pipeline::{self, Op};
    ///
    /// let chain = pipeline::new()
    ///     .lookup(index, 2)
    ///     .then(|(query, docs): (_, Vec<String>)| async move {
    ///         format!("User query: {}\n\nTop documents:\n{}", query, docs.join("\n"))
    ///     });
    ///
    /// let result = chain.call("What is a flurbo?".to_string()).await;
    /// ```
    pub fn lookup<I, In, T>(
        self,
        index: I,
        n: usize,
    ) -> agent_ops::Lookup<I, In, T>
    where
        I: vector_store::VectorStoreIndex,
        T: Send + Sync + for<'a> serde::Deserialize<'a>,
        In: Into<String> + Send + Sync,
        Self: Sized,
    {
        agent_ops::Lookup::new(index, n)
    }

    /// Chain a prompt operation to the current chain. The prompt operation expects the
    /// current chain to output a string. The prompt operation will use the string to prompt
    /// the given agent (or any other type that implements the `Prompt` trait) and return
    /// the response.
    ///
    /// # Example
    /// ```rust
    /// use rig::pipeline::{self, Op};
    ///
    /// let agent = &openai_client.agent("gpt-4").build();
    ///
    /// let chain = pipeline::new()
    ///     .map(|name| format!("Find funny nicknames for the following name: {name}!"))
    ///     .prompt(agent);
    ///
    /// let result = chain.call("Alice".to_string()).await;
    /// ```
    pub fn prompt<P, In>(
        self,
        prompt: P,
    ) -> impl op::Op<Input = In, Output = Result<String, completion::PromptError>>
    where
        P: completion::Prompt,
        In: Into<String> + Send + Sync,
        Self: Sized,
    {
        agent_ops::prompt(prompt)
    }
}
/// Errors that can occur while running a pipeline.
/// Wraps the failure modes of the two fallible ops: prompting an agent
/// and looking up documents in a vector store.
#[derive(Debug, thiserror::Error)]
pub enum ChainError {
    /// The agent (or other `Prompt` implementor) failed to produce a response.
    #[error("Failed to prompt agent: {0}")]
    PromptError(#[from] completion::PromptError),

    /// The vector store lookup failed.
    #[error("Failed to lookup documents: {0}")]
    LookupError(#[from] vector_store::VectorStoreError),
}
// pub fn new() -> PipelineBuilder<ChainError> {
// PipelineBuilder {
// _error: std::marker::PhantomData,
// }
// }
// pub fn with_error<E>() -> PipelineBuilder<E> {
// PipelineBuilder {
// _error: std::marker::PhantomData,
// }
// }
/// Create a new pipeline builder.
///
/// Starting point for constructing a pipeline; chain ops onto the returned
/// [PipelineBuilder] with `map`, `then`, `chain`, `lookup` or `prompt`.
pub fn new() -> PipelineBuilder {
    PipelineBuilder
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pipeline::{op::Op, parallel::{parallel, Parallel}};
    use agent_ops::tests::{Foo, MockIndex, MockModel};

    // Pipeline of map -> prompt: the mock model echoes its input, so the
    // result pins the exact string the chain produced.
    #[tokio::test]
    async fn test_prompt_pipeline() {
        let model = MockModel;

        let chain = super::new()
            .map(|input| format!("User query: {}", input))
            .prompt(model);

        let result = chain
            .call("What is a flurbo?")
            .await
            .expect("Failed to run chain");

        assert_eq!(result, "Mock response: User query: What is a flurbo?");
    }

    // #[tokio::test]
    // async fn test_lookup_pipeline() {
    //     let index = MockIndex;
    //     let chain = super::new()
    //         .lookup::<_, _, Foo>(index, 1)
    //         .map_ok(|docs| format!("Top documents:\n{}", docs[0].foo));
    //     let result = chain
    //         .try_call("What is a flurbo?")
    //         .await
    //         .expect("Failed to run chain");
    //     assert_eq!(
    //         result,
    //         "User query: What is a flurbo?\n\nTop documents:\nbar"
    //     );
    // }

    // RAG-style pipeline: fan out the query through a passthrough op and a
    // lookup in parallel, merge (query, docs) into a prompt string, then
    // prompt the mock model. The lookup's Result is handled in the map.
    #[tokio::test]
    async fn test_rag_pipeline() {
        let index = MockIndex;

        let chain = super::new()
            .chain(parallel!(op::passthrough(), agent_ops::lookup::<_, _, Foo>(index, 1),))
            .map(|(query, maybe_docs)| match maybe_docs {
                Ok(docs) => format!("User query: {}\n\nTop documents:\n{}", query, docs[0].foo),
                Err(err) => format!("Error: {}", err),
            })
            .prompt(MockModel);

        let result = chain
            .call("What is a flurbo?")
            .await
            .expect("Failed to run chain");

        assert_eq!(
            result,
            "Mock response: User query: What is a flurbo?\n\nTop documents:\nbar"
        );
    }

    // Compile-only check: nested Parallel ops produce a nested tuple that the
    // final map flattens. No assertions; the test passes if this type-checks.
    #[tokio::test]
    async fn test_parallel_chain_compile_check() {
        let _ = super::new().chain(
            Parallel::new(
                op::map(|x: i32| x + 1),
                Parallel::new(
                    op::map(|x: i32| x * 3),
                    Parallel::new(
                        op::map(|x: i32| format!("{} is the number!", x)),
                        op::map(|x: i32| x == 1),
                    ),
                ),
            )
            .map(|(r1, (r2, (r3, r4)))| (r1, r2, r3, r4)),
        );
    }
}

View File

@ -13,7 +13,7 @@ use crate::{
OneOrMany,
};
/// InMemoryVectorStore is a simple in-memory vector store that stores embeddings
/// [InMemoryVectorStore] is a simple in-memory vector store that stores embeddings
/// in-memory using a HashMap.
#[derive(Clone, Default)]
pub struct InMemoryVectorStore<D: Serialize> {
@ -24,8 +24,49 @@ pub struct InMemoryVectorStore<D: Serialize> {
}
impl<D: Serialize + Eq> InMemoryVectorStore<D> {
/// Implement vector search on InMemoryVectorStore.
/// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for InMemoryVectorStore.
/// Create a new [InMemoryVectorStore] from documents and their corresponding embeddings.
/// Ids are automatically generated and will have the form `"doc{n}"` where `n`
/// is the index of the document.
pub fn from_documents(documents: impl IntoIterator<Item = (D, OneOrMany<Embedding>)>) -> Self {
    let mut store = HashMap::new();
    documents
        .into_iter()
        .enumerate()
        .for_each(|(i, (doc, embeddings))| {
            // Key each document by its position in the input iterator.
            store.insert(format!("doc{i}"), (doc, embeddings));
        });
    Self { embeddings: store }
}
/// Create a new [InMemoryVectorStore] from documents and their corresponding embeddings with ids.
/// Each entry supplies its own id; duplicate ids overwrite earlier entries.
pub fn from_documents_with_ids(
    documents: impl IntoIterator<Item = (impl ToString, D, OneOrMany<Embedding>)>,
) -> Self {
    let mut store = HashMap::new();
    documents.into_iter().for_each(|(i, doc, embeddings)| {
        store.insert(i.to_string(), (doc, embeddings));
    });
    Self { embeddings: store }
}
/// Create a new [InMemoryVectorStore] from documents and their corresponding embeddings.
/// Document ids are generated using the provided function.
/// If `f` yields the same id for two documents, the later one overwrites the earlier.
pub fn from_documents_with_id_f(
    documents: impl IntoIterator<Item = (D, OneOrMany<Embedding>)>,
    f: fn(&D) -> String,
) -> Self {
    let mut store = HashMap::new();
    documents.into_iter().for_each(|(doc, embeddings)| {
        // Derive the id from the document itself before moving it into the map.
        store.insert(f(&doc), (doc, embeddings));
    });
    Self { embeddings: store }
}
/// Implement vector search on [InMemoryVectorStore].
/// To be used by implementations of [VectorStoreIndex::top_n] and [VectorStoreIndex::top_n_ids] methods.
fn vector_search(&self, prompt_embedding: &Embedding, n: usize) -> EmbeddingRanking<D> {
// Sort documents by best embedding distance
let mut docs = BinaryHeap::new();
@ -63,32 +104,44 @@ impl<D: Serialize + Eq> InMemoryVectorStore<D> {
docs
}
/// Add documents to the store.
/// Returns the store with the added documents.
/// Add documents and their corresponding embeddings to the store.
/// Ids are automatically generated and will have the form `"doc{n}"` where `n`
/// is the index of the document.
pub fn add_documents(
mut self,
documents: Vec<(String, D, OneOrMany<Embedding>)>,
) -> Result<Self, VectorStoreError> {
for (id, doc, embeddings) in documents {
self.embeddings.insert(id, (doc, embeddings));
}
Ok(self)
&mut self,
documents: impl IntoIterator<Item = (D, OneOrMany<Embedding>)>,
) {
let current_index = self.embeddings.len();
documents
.into_iter()
.enumerate()
.for_each(|(index, (doc, embeddings))| {
self.embeddings
.insert(format!("doc{}", index + current_index), (doc, embeddings));
});
}
/// Add documents to the store. Define a function that takes as input the reference of the document and returns its id.
/// Returns the store with the added documents.
pub fn add_documents_with_id(
mut self,
/// Add documents and their corresponding embeddings to the store with ids.
/// Inserting an id that already exists replaces that entry.
pub fn add_documents_with_ids(
    &mut self,
    documents: impl IntoIterator<Item = (impl ToString, D, OneOrMany<Embedding>)>,
) {
    documents.into_iter().for_each(|(id, doc, embeddings)| {
        self.embeddings.insert(id.to_string(), (doc, embeddings));
    });
}
/// Add documents and their corresponding embeddings to the store.
/// Document ids are generated using the provided function.
pub fn add_documents_with_id_f(
&mut self,
documents: Vec<(D, OneOrMany<Embedding>)>,
id_f: fn(&D) -> String,
) -> Result<Self, VectorStoreError> {
f: fn(&D) -> String,
) {
for (doc, embeddings) in documents {
let id = id_f(&doc);
let id = f(&doc);
self.embeddings.insert(id, (doc, embeddings));
}
Ok(self)
}
/// Get the document by its id and deserialize it into the given type.
@ -215,37 +268,138 @@ mod tests {
use super::{InMemoryVectorStore, RankingItem};
#[test]
fn test_single_embedding() {
let index = InMemoryVectorStore::default()
.add_documents(vec![
fn test_auto_ids() {
let mut vector_store = InMemoryVectorStore::from_documents(vec![
(
"glarb-garb",
OneOrMany::one(Embedding {
document: "glarb-garb".to_string(),
vec: vec![0.1, 0.1, 0.5],
}),
),
(
"marble-marble",
OneOrMany::one(Embedding {
document: "marble-marble".to_string(),
vec: vec![0.7, -0.3, 0.0],
}),
),
(
"flumb-flumb",
OneOrMany::one(Embedding {
document: "flumb-flumb".to_string(),
vec: vec![0.3, 0.7, 0.1],
}),
),
]);
vector_store.add_documents(vec![
(
"brotato",
OneOrMany::one(Embedding {
document: "brotato".to_string(),
vec: vec![0.3, 0.7, 0.1],
}),
),
(
"ping-pong",
OneOrMany::one(Embedding {
document: "ping-pong".to_string(),
vec: vec![0.7, -0.3, 0.0],
}),
),
]);
let mut store = vector_store.embeddings.into_iter().collect::<Vec<_>>();
store.sort_by_key(|(id, _)| id.clone());
assert_eq!(
store,
vec![
(
"doc0".to_string(),
(
"glarb-garb",
OneOrMany::one(Embedding {
document: "glarb-garb".to_string(),
vec: vec![0.1, 0.1, 0.5],
})
)
),
(
"doc1".to_string(),
"glarb-garb",
OneOrMany::one(Embedding {
document: "glarb-garb".to_string(),
vec: vec![0.1, 0.1, 0.5],
}),
(
"marble-marble",
OneOrMany::one(Embedding {
document: "marble-marble".to_string(),
vec: vec![0.7, -0.3, 0.0],
})
)
),
(
"doc2".to_string(),
"marble-marble",
OneOrMany::one(Embedding {
document: "marble-marble".to_string(),
vec: vec![0.7, -0.3, 0.0],
}),
(
"flumb-flumb",
OneOrMany::one(Embedding {
document: "flumb-flumb".to_string(),
vec: vec![0.3, 0.7, 0.1],
})
)
),
(
"doc3".to_string(),
"flumb-flumb",
OneOrMany::one(Embedding {
document: "flumb-flumb".to_string(),
vec: vec![0.3, 0.7, 0.1],
}),
(
"brotato",
OneOrMany::one(Embedding {
document: "brotato".to_string(),
vec: vec![0.3, 0.7, 0.1],
})
)
),
])
.unwrap();
(
"doc4".to_string(),
(
"ping-pong",
OneOrMany::one(Embedding {
document: "ping-pong".to_string(),
vec: vec![0.7, -0.3, 0.0],
})
)
)
]
);
}
let ranking = index.vector_search(
#[test]
fn test_single_embedding() {
let vector_store = InMemoryVectorStore::from_documents_with_ids(vec![
(
"doc1",
"glarb-garb",
OneOrMany::one(Embedding {
document: "glarb-garb".to_string(),
vec: vec![0.1, 0.1, 0.5],
}),
),
(
"doc2",
"marble-marble",
OneOrMany::one(Embedding {
document: "marble-marble".to_string(),
vec: vec![0.7, -0.3, 0.0],
}),
),
(
"doc3",
"flumb-flumb",
OneOrMany::one(Embedding {
document: "flumb-flumb".to_string(),
vec: vec![0.3, 0.7, 0.1],
}),
),
]);
let ranking = vector_store.vector_search(
&Embedding {
document: "glarby-glarble".to_string(),
vec: vec![0.0, 0.1, 0.6],
@ -274,57 +428,55 @@ mod tests {
#[test]
fn test_multiple_embeddings() {
let index = InMemoryVectorStore::default()
.add_documents(vec![
(
"doc1".to_string(),
"glarb-garb",
OneOrMany::many(vec![
Embedding {
document: "glarb-garb".to_string(),
vec: vec![0.1, 0.1, 0.5],
},
Embedding {
document: "don't-choose-me".to_string(),
vec: vec![-0.5, 0.9, 0.1],
},
])
.unwrap(),
),
(
"doc2".to_string(),
"marble-marble",
OneOrMany::many(vec![
Embedding {
document: "marble-marble".to_string(),
vec: vec![0.7, -0.3, 0.0],
},
Embedding {
document: "sandwich".to_string(),
vec: vec![0.5, 0.5, -0.7],
},
])
.unwrap(),
),
(
"doc3".to_string(),
"flumb-flumb",
OneOrMany::many(vec![
Embedding {
document: "flumb-flumb".to_string(),
vec: vec![0.3, 0.7, 0.1],
},
Embedding {
document: "banana".to_string(),
vec: vec![0.1, -0.5, -0.5],
},
])
.unwrap(),
),
])
.unwrap();
let vector_store = InMemoryVectorStore::from_documents_with_ids(vec![
(
"doc1",
"glarb-garb",
OneOrMany::many(vec![
Embedding {
document: "glarb-garb".to_string(),
vec: vec![0.1, 0.1, 0.5],
},
Embedding {
document: "don't-choose-me".to_string(),
vec: vec![-0.5, 0.9, 0.1],
},
])
.unwrap(),
),
(
"doc2",
"marble-marble",
OneOrMany::many(vec![
Embedding {
document: "marble-marble".to_string(),
vec: vec![0.7, -0.3, 0.0],
},
Embedding {
document: "sandwich".to_string(),
vec: vec![0.5, 0.5, -0.7],
},
])
.unwrap(),
),
(
"doc3",
"flumb-flumb",
OneOrMany::many(vec![
Embedding {
document: "flumb-flumb".to_string(),
vec: vec![0.3, 0.7, 0.1],
},
Embedding {
document: "banana".to_string(),
vec: vec![0.1, -0.5, -0.5],
},
])
.unwrap(),
),
]);
let ranking = index.vector_search(
let ranking = vector_store.vector_search(
&Embedding {
document: "glarby-glarble".to_string(),
vec: vec![0.0, 0.1, 0.6],