feat(pinecone): rig pinecone integration (WIP)

This commit is contained in:
Joshua Mo 2025-04-09 22:00:10 +01:00
parent 92c91d23c3
commit d88b7a62a2
4 changed files with 491 additions and 28 deletions

294
Cargo.lock generated
View File

@ -1452,6 +1452,34 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "axum"
version = "0.6.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf"
dependencies = [
"async-trait",
"axum-core 0.3.4",
"bitflags 1.3.2",
"bytes",
"futures-util",
"http 0.2.12",
"http-body 0.4.6",
"hyper 0.14.32",
"itoa",
"matchit",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"rustversion",
"serde",
"sync_wrapper 0.1.2",
"tower 0.4.13",
"tower-layer",
"tower-service",
]
[[package]] [[package]]
name = "axum" name = "axum"
version = "0.7.9" version = "0.7.9"
@ -1459,7 +1487,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"axum-core", "axum-core 0.4.5",
"bytes", "bytes",
"futures-util", "futures-util",
"http 1.3.1", "http 1.3.1",
@ -1479,6 +1507,23 @@ dependencies = [
"tower-service", "tower-service",
] ]
[[package]]
name = "axum-core"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "759fa577a247914fd3f7f76d62972792636412fbfd634cd452f6a385a74d2d2c"
dependencies = [
"async-trait",
"bytes",
"futures-util",
"http 0.2.12",
"http-body 0.4.6",
"mime",
"rustversion",
"tower-layer",
"tower-service",
]
[[package]] [[package]]
name = "axum-core" name = "axum-core"
version = "0.4.5" version = "0.4.5"
@ -5032,6 +5077,18 @@ dependencies = [
"webpki-roots 0.26.8", "webpki-roots 0.26.8",
] ]
[[package]]
name = "hyper-timeout"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1"
dependencies = [
"hyper 0.14.32",
"pin-project-lite",
"tokio",
"tokio-io-timeout",
]
[[package]] [[package]]
name = "hyper-timeout" name = "hyper-timeout"
version = "0.5.2" version = "0.5.2"
@ -5684,8 +5741,8 @@ dependencies = [
"object_store 0.11.2", "object_store 0.11.2",
"permutation", "permutation",
"pin-project", "pin-project",
"prost", "prost 0.13.5",
"prost-types", "prost-types 0.13.5",
"rand 0.8.5", "rand 0.8.5",
"roaring", "roaring",
"serde", "serde",
@ -5744,7 +5801,7 @@ dependencies = [
"num_cpus", "num_cpus",
"object_store 0.11.2", "object_store 0.11.2",
"pin-project", "pin-project",
"prost", "prost 0.13.5",
"rand 0.8.5", "rand 0.8.5",
"roaring", "roaring",
"serde_json", "serde_json",
@ -5778,7 +5835,7 @@ dependencies = [
"lance-core", "lance-core",
"lazy_static", "lazy_static",
"log", "log",
"prost", "prost 0.13.5",
"snafu 0.8.5", "snafu 0.8.5",
"tokio", "tokio",
] ]
@ -5812,9 +5869,9 @@ dependencies = [
"log", "log",
"num-traits", "num-traits",
"paste", "paste",
"prost", "prost 0.13.5",
"prost-build", "prost-build 0.13.5",
"prost-types", "prost-types 0.13.5",
"rand 0.8.5", "rand 0.8.5",
"seq-macro", "seq-macro",
"snafu 0.8.5", "snafu 0.8.5",
@ -5849,9 +5906,9 @@ dependencies = [
"log", "log",
"num-traits", "num-traits",
"object_store 0.11.2", "object_store 0.11.2",
"prost", "prost 0.13.5",
"prost-build", "prost-build 0.13.5",
"prost-types", "prost-types 0.13.5",
"roaring", "roaring",
"snafu 0.8.5", "snafu 0.8.5",
"tempfile", "tempfile",
@ -5898,8 +5955,8 @@ dependencies = [
"moka", "moka",
"num-traits", "num-traits",
"object_store 0.11.2", "object_store 0.11.2",
"prost", "prost 0.13.5",
"prost-build", "prost-build 0.13.5",
"rand 0.8.5", "rand 0.8.5",
"rayon", "rayon",
"roaring", "roaring",
@ -5944,7 +6001,7 @@ dependencies = [
"object_store 0.11.2", "object_store 0.11.2",
"path_abs", "path_abs",
"pin-project", "pin-project",
"prost", "prost 0.13.5",
"rand 0.8.5", "rand 0.8.5",
"shellexpand", "shellexpand",
"snafu 0.8.5", "snafu 0.8.5",
@ -6004,9 +6061,9 @@ dependencies = [
"lazy_static", "lazy_static",
"log", "log",
"object_store 0.11.2", "object_store 0.11.2",
"prost", "prost 0.13.5",
"prost-build", "prost-build 0.13.5",
"prost-types", "prost-types 0.13.5",
"rand 0.8.5", "rand 0.8.5",
"rangemap", "rangemap",
"roaring", "roaring",
@ -7696,6 +7753,30 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "pinecone-sdk"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f571fcb86d81e70a5de2817a029fa9e52160f66f10d662584b56607ae6c5dab9"
dependencies = [
"anyhow",
"once_cell",
"prost 0.12.6",
"prost-types 0.12.6",
"rand 0.8.5",
"regex",
"reqwest 0.12.15",
"serde",
"serde_json",
"snafu 0.8.5",
"thiserror 1.0.69",
"tokio",
"tonic 0.11.0",
"tonic-build",
"url",
"uuid 1.16.0",
]
[[package]] [[package]]
name = "piper" name = "piper"
version = "0.2.4" version = "0.2.4"
@ -7902,6 +7983,16 @@ dependencies = [
"unarray", "unarray",
] ]
[[package]]
name = "prost"
version = "0.12.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "deb1435c188b76130da55f17a466d252ff7b1418b2ad3e037d127b94e3411f29"
dependencies = [
"bytes",
"prost-derive 0.12.6",
]
[[package]] [[package]]
name = "prost" name = "prost"
version = "0.13.5" version = "0.13.5"
@ -7909,7 +8000,28 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5"
dependencies = [ dependencies = [
"bytes", "bytes",
"prost-derive", "prost-derive 0.13.5",
]
[[package]]
name = "prost-build"
version = "0.12.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4"
dependencies = [
"bytes",
"heck 0.5.0",
"itertools 0.12.1",
"log",
"multimap",
"once_cell",
"petgraph 0.6.5",
"prettyplease",
"prost 0.12.6",
"prost-types 0.12.6",
"regex",
"syn 2.0.100",
"tempfile",
] ]
[[package]] [[package]]
@ -7925,13 +8037,26 @@ dependencies = [
"once_cell", "once_cell",
"petgraph 0.7.1", "petgraph 0.7.1",
"prettyplease", "prettyplease",
"prost", "prost 0.13.5",
"prost-types", "prost-types 0.13.5",
"regex", "regex",
"syn 2.0.100", "syn 2.0.100",
"tempfile", "tempfile",
] ]
[[package]]
name = "prost-derive"
version = "0.12.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1"
dependencies = [
"anyhow",
"itertools 0.12.1",
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]] [[package]]
name = "prost-derive" name = "prost-derive"
version = "0.13.5" version = "0.13.5"
@ -7945,13 +8070,22 @@ dependencies = [
"syn 2.0.100", "syn 2.0.100",
] ]
[[package]]
name = "prost-types"
version = "0.12.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9091c90b0a32608e984ff2fa4091273cbdd755d54935c51d520887f4a1dbd5b0"
dependencies = [
"prost 0.12.6",
]
[[package]] [[package]]
name = "prost-types" name = "prost-types"
version = "0.13.5" version = "0.13.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16" checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16"
dependencies = [ dependencies = [
"prost", "prost 0.13.5",
] ]
[[package]] [[package]]
@ -7999,15 +8133,15 @@ dependencies = [
"derive_builder", "derive_builder",
"futures", "futures",
"futures-util", "futures-util",
"prost", "prost 0.13.5",
"prost-types", "prost-types 0.13.5",
"reqwest 0.12.15", "reqwest 0.12.15",
"semver", "semver",
"serde", "serde",
"serde_json", "serde_json",
"thiserror 1.0.69", "thiserror 1.0.69",
"tokio", "tokio",
"tonic", "tonic 0.12.3",
] ]
[[package]] [[package]]
@ -8787,6 +8921,22 @@ dependencies = [
"tracing-subscriber", "tracing-subscriber",
] ]
[[package]]
name = "rig-pinecone"
version = "0.1.0"
dependencies = [
"anyhow",
"httpmock",
"pinecone-sdk",
"prost-types 0.12.6",
"rig-core 0.11.0",
"serde",
"serde_json",
"testcontainers",
"tokio",
"uuid 1.16.0",
]
[[package]] [[package]]
name = "rig-postgres" name = "rig-postgres"
version = "0.1.6" version = "0.1.6"
@ -9156,6 +9306,20 @@ dependencies = [
"sct", "sct",
] ]
[[package]]
name = "rustls"
version = "0.22.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432"
dependencies = [
"log",
"ring 0.17.14",
"rustls-pki-types",
"rustls-webpki 0.102.8",
"subtle",
"zeroize",
]
[[package]] [[package]]
name = "rustls" name = "rustls"
version = "0.23.25" version = "0.23.25"
@ -9246,6 +9410,17 @@ dependencies = [
"untrusted 0.9.0", "untrusted 0.9.0",
] ]
[[package]]
name = "rustls-webpki"
version = "0.102.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9"
dependencies = [
"ring 0.17.14",
"rustls-pki-types",
"untrusted 0.9.0",
]
[[package]] [[package]]
name = "rustls-webpki" name = "rustls-webpki"
version = "0.103.1" version = "0.103.1"
@ -11048,6 +11223,16 @@ dependencies = [
"windows-sys 0.52.0", "windows-sys 0.52.0",
] ]
[[package]]
name = "tokio-io-timeout"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30b74022ada614a1b4834de765f9bb43877f910cc8ce4be40e89042c9223a8bf"
dependencies = [
"pin-project-lite",
"tokio",
]
[[package]] [[package]]
name = "tokio-macros" name = "tokio-macros"
version = "2.5.0" version = "2.5.0"
@ -11090,6 +11275,17 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "tokio-rustls"
version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "775e0c0f0adb3a2f22a00c4745d728b479985fc15ee7ca6a2608388c5569860f"
dependencies = [
"rustls 0.22.4",
"rustls-pki-types",
"tokio",
]
[[package]] [[package]]
name = "tokio-rustls" name = "tokio-rustls"
version = "0.26.2" version = "0.26.2"
@ -11230,6 +11426,37 @@ dependencies = [
"winnow", "winnow",
] ]
[[package]]
name = "tonic"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76c4eb7a4e9ef9d4763600161f12f5070b92a578e1b634db88a6887844c91a13"
dependencies = [
"async-stream",
"async-trait",
"axum 0.6.20",
"base64 0.21.7",
"bytes",
"h2 0.3.26",
"http 0.2.12",
"http-body 0.4.6",
"hyper 0.14.32",
"hyper-timeout 0.4.1",
"percent-encoding",
"pin-project",
"prost 0.12.6",
"rustls-native-certs 0.7.3",
"rustls-pemfile 2.2.0",
"rustls-pki-types",
"tokio",
"tokio-rustls 0.25.0",
"tokio-stream",
"tower 0.4.13",
"tower-layer",
"tower-service",
"tracing",
]
[[package]] [[package]]
name = "tonic" name = "tonic"
version = "0.12.3" version = "0.12.3"
@ -11238,7 +11465,7 @@ checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"async-trait", "async-trait",
"axum", "axum 0.7.9",
"base64 0.22.1", "base64 0.22.1",
"bytes", "bytes",
"flate2", "flate2",
@ -11247,11 +11474,11 @@ dependencies = [
"http-body 1.0.1", "http-body 1.0.1",
"http-body-util", "http-body-util",
"hyper 1.6.0", "hyper 1.6.0",
"hyper-timeout", "hyper-timeout 0.5.2",
"hyper-util", "hyper-util",
"percent-encoding", "percent-encoding",
"pin-project", "pin-project",
"prost", "prost 0.13.5",
"rustls-native-certs 0.8.1", "rustls-native-certs 0.8.1",
"rustls-pemfile 2.2.0", "rustls-pemfile 2.2.0",
"socket2", "socket2",
@ -11264,6 +11491,19 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "tonic-build"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be4ef6dd70a610078cb4e338a0f79d06bc759ff1b22d2120c2ff02ae264ba9c2"
dependencies = [
"prettyplease",
"proc-macro2",
"prost-build 0.12.6",
"quote",
"syn 2.0.100",
]
[[package]] [[package]]
name = "tower" name = "tower"
version = "0.4.13" version = "0.4.13"

View File

@ -12,5 +12,5 @@ members = [
"rig-surrealdb", "rig-surrealdb",
"rig-eternalai", "rig-eternalai",
"rig-fastembed", "rig-fastembed",
"rig-bedrock", "rig-bedrock", "rig-pinecone",
] ]

22
rig-pinecone/Cargo.toml Normal file
View File

@ -0,0 +1,22 @@
[package]
name = "rig-pinecone"
version = "0.1.0"
edition = "2021"
license = "MIT"
readme = "README.md"
description = "Rig vector store index integration for Pinecone. https://www.pinecone.io/"
repository = "https://github.com/0xPlaygrounds/rig"
[dependencies]
rig-core = { path = "../rig-core", version = "0.11.0" }
serde_json = "1.0.128"
serde = "1.0.210"
pinecone-sdk = "0.1.2"
uuid = { version = "1.13.1", features = ["v4"] }
prost-types = "0.12"
[dev-dependencies]
tokio = { version = "1.40.0", features = ["rt-multi-thread"] }
anyhow = "1.0.89"
testcontainers = "0.23.1"
httpmock = "0.7.0"

201
rig-pinecone/src/lib.rs Normal file
View File

@ -0,0 +1,201 @@
use std::collections::BTreeMap;
use pinecone_sdk::models::{Kind, Metadata, Namespace, QueryResponse, Value, Vector};
use pinecone_sdk::pinecone::data::Index;
use pinecone_sdk::pinecone::PineconeClient;
use prost_types::ListValue;
use rig::embeddings::EmbeddingModel;
use rig::vector_store::{VectorStoreError, VectorStoreIndex};
use rig::{embeddings::Embedding, Embed, OneOrMany};
use serde::Serialize;
use serde_json::Value as JsonValue;
pub struct PineconeVectorStore<M> {
model: M,
client: PineconeClient,
index_name: String,
namespace: Namespace,
}
impl<M> PineconeVectorStore<M>
where
M: EmbeddingModel,
{
pub fn new<S, N>(client: PineconeClient, index_name: S, model: M, namespace: N) -> Self
where
S: Into<String>,
N: Into<Namespace>,
{
let index_name: String = index_name.into();
let namespace: Namespace = namespace.into();
Self {
client,
model,
index_name,
namespace,
}
}
pub fn update_index_name(&mut self, index_name: &str) {
self.index_name = index_name.to_string();
}
pub fn namespace(&self) -> &Namespace {
&self.namespace
}
pub fn update_namespace(&mut self, namespace: Namespace) {
self.namespace = namespace;
}
pub async fn insert_documents<Doc: Serialize + Embed + Send>(
&mut self,
documents: Vec<(Doc, OneOrMany<Embedding>)>,
namespace: &Namespace,
) -> Result<(), VectorStoreError> {
let vectors: Vec<Vector> = documents
.into_iter()
.map(|(doc, embedding)| {
let metadata = {
let json_value: JsonValue = serde_json::to_value(&doc).unwrap();
json_to_metadata(&json_value)
};
let values = embedding.first().vec.iter().map(|&x| x as f32).collect();
Vector {
id: uuid::Uuid::new_v4().to_string(),
values,
sparse_values: None,
metadata: Some(metadata),
}
})
.collect();
let mut idx = self
.client
.index(&self.index_name)
.await
.map_err(|x| VectorStoreError::DatastoreError(x.into()))?;
idx.upsert(&vectors, namespace)
.await
.map_err(|x| VectorStoreError::DatastoreError(x.into()))?;
Ok(())
}
/// Embed query based on `QdrantVectorStore` model and modify the vector in the required format.
pub async fn generate_query_vector(&self, query: &str) -> Result<Vec<f32>, VectorStoreError> {
let embedding = self.model.embed_text(query).await?;
Ok(embedding.vec.iter().map(|&x| x as f32).collect())
}
}
impl<M> VectorStoreIndex for PineconeVectorStore<M>
where
M: EmbeddingModel,
{
async fn top_n<T: for<'a> serde::Deserialize<'a> + Send>(
&self,
query: &str,
n: usize,
) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
let vector = self.generate_query_vector(query).await?;
let mut index = self
.client
.index(&self.index_name)
.await
.map_err(|x| VectorStoreError::DatastoreError(x.into()))?;
let res: QueryResponse = index
.query_by_value(
vector,
None,
n as u32,
self.namespace(),
None,
Some(true),
Some(true),
)
.await
.map_err(|x| VectorStoreError::DatastoreError(x.into()))?;
todo!()
}
async fn top_n_ids(
&self,
query: &str,
n: usize,
) -> Result<Vec<(f64, String)>, VectorStoreError> {
todo!()
}
}
pub fn json_to_metadata(json: &JsonValue) -> Metadata {
match json {
JsonValue::Object(map) => {
let fields = map
.iter()
.map(|(k, v)| (k.clone(), json_to_kind_value(v)))
.collect();
Metadata { fields }
}
_ => {
// Not a JSON object — return empty metadata or panic based on your needs
Metadata {
fields: BTreeMap::new(),
}
}
}
}
fn json_to_kind_value(json: &JsonValue) -> Value {
let kind = match json {
JsonValue::Null => Some(Kind::NullValue(0)),
JsonValue::Bool(b) => Some(Kind::BoolValue(*b)),
JsonValue::Number(n) => n
.as_f64()
.map(Kind::NumberValue)
.or_else(|| n.as_i64().map(|i| Kind::NumberValue(i as f64)))
.or_else(|| n.as_u64().map(|u| Kind::NumberValue(u as f64))),
JsonValue::String(s) => Some(Kind::StringValue(s.clone())),
JsonValue::Array(arr) => Some(Kind::ListValue(ListValue {
values: arr.iter().map(json_to_kind_value).collect(),
})),
JsonValue::Object(map) => Some(Kind::StructValue(Metadata {
fields: map
.iter()
.map(|(k, v)| (k.clone(), json_to_kind_value(v)))
.collect(),
})),
};
Value { kind }
}
pub fn metadata_to_json_value(metadata: &Metadata) -> serde_json::Value {
let mut map = serde_json::Map::new();
for (k, v) in &metadata.fields {
map.insert(k.clone(), convert_value_to_json(v));
}
serde_json::Value::Object(map)
}
fn convert_value_to_json(value: &Value) -> serde_json::Value {
match &value.kind {
Some(Kind::NullValue(_)) => serde_json::Value::Null,
Some(Kind::BoolValue(b)) => serde_json::Value::Bool(*b),
Some(Kind::NumberValue(n)) => serde_json::Value::Number(
serde_json::Number::from_f64(*n).expect("Invalid f64 for JSON number"),
),
Some(Kind::StringValue(s)) => serde_json::Value::String(s.clone()),
Some(Kind::ListValue(list)) => {
let arr = list.values.iter().map(convert_value_to_json).collect();
serde_json::Value::Array(arr)
}
Some(Kind::StructValue(struct_val)) => metadata_to_json_value(struct_val),
None => serde_json::Value::Null,
}
}