From d88b7a62a2f5454b977501346062bdf2e4aa8603 Mon Sep 17 00:00:00 2001 From: Joshua Mo Date: Wed, 9 Apr 2025 22:00:10 +0100 Subject: [PATCH] feat(pinecone): rig pinecone integration (WIP) --- Cargo.lock | 294 ++++++++++++++++++++++++++++++++++++---- Cargo.toml | 2 +- rig-pinecone/Cargo.toml | 22 +++ rig-pinecone/src/lib.rs | 201 +++++++++++++++++++++++++++ 4 files changed, 491 insertions(+), 28 deletions(-) create mode 100644 rig-pinecone/Cargo.toml create mode 100644 rig-pinecone/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index d647af4..fb9324e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1452,6 +1452,34 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum" +version = "0.6.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf" +dependencies = [ + "async-trait", + "axum-core 0.3.4", + "bitflags 1.3.2", + "bytes", + "futures-util", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.32", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "sync_wrapper 0.1.2", + "tower 0.4.13", + "tower-layer", + "tower-service", +] + [[package]] name = "axum" version = "0.7.9" @@ -1459,7 +1487,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" dependencies = [ "async-trait", - "axum-core", + "axum-core 0.4.5", "bytes", "futures-util", "http 1.3.1", @@ -1479,6 +1507,23 @@ dependencies = [ "tower-service", ] +[[package]] +name = "axum-core" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "759fa577a247914fd3f7f76d62972792636412fbfd634cd452f6a385a74d2d2c" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http 0.2.12", + "http-body 0.4.6", + "mime", + "rustversion", + "tower-layer", + "tower-service", +] + [[package]] name = "axum-core" version = "0.4.5" @@ 
-5032,6 +5077,18 @@ dependencies = [ "webpki-roots 0.26.8", ] +[[package]] +name = "hyper-timeout" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" +dependencies = [ + "hyper 0.14.32", + "pin-project-lite", + "tokio", + "tokio-io-timeout", +] + [[package]] name = "hyper-timeout" version = "0.5.2" @@ -5684,8 +5741,8 @@ dependencies = [ "object_store 0.11.2", "permutation", "pin-project", - "prost", - "prost-types", + "prost 0.13.5", + "prost-types 0.13.5", "rand 0.8.5", "roaring", "serde", @@ -5744,7 +5801,7 @@ dependencies = [ "num_cpus", "object_store 0.11.2", "pin-project", - "prost", + "prost 0.13.5", "rand 0.8.5", "roaring", "serde_json", @@ -5778,7 +5835,7 @@ dependencies = [ "lance-core", "lazy_static", "log", - "prost", + "prost 0.13.5", "snafu 0.8.5", "tokio", ] @@ -5812,9 +5869,9 @@ dependencies = [ "log", "num-traits", "paste", - "prost", - "prost-build", - "prost-types", + "prost 0.13.5", + "prost-build 0.13.5", + "prost-types 0.13.5", "rand 0.8.5", "seq-macro", "snafu 0.8.5", @@ -5849,9 +5906,9 @@ dependencies = [ "log", "num-traits", "object_store 0.11.2", - "prost", - "prost-build", - "prost-types", + "prost 0.13.5", + "prost-build 0.13.5", + "prost-types 0.13.5", "roaring", "snafu 0.8.5", "tempfile", @@ -5898,8 +5955,8 @@ dependencies = [ "moka", "num-traits", "object_store 0.11.2", - "prost", - "prost-build", + "prost 0.13.5", + "prost-build 0.13.5", "rand 0.8.5", "rayon", "roaring", @@ -5944,7 +6001,7 @@ dependencies = [ "object_store 0.11.2", "path_abs", "pin-project", - "prost", + "prost 0.13.5", "rand 0.8.5", "shellexpand", "snafu 0.8.5", @@ -6004,9 +6061,9 @@ dependencies = [ "lazy_static", "log", "object_store 0.11.2", - "prost", - "prost-build", - "prost-types", + "prost 0.13.5", + "prost-build 0.13.5", + "prost-types 0.13.5", "rand 0.8.5", "rangemap", "roaring", @@ -7696,6 +7753,30 @@ version = "0.1.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pinecone-sdk" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f571fcb86d81e70a5de2817a029fa9e52160f66f10d662584b56607ae6c5dab9" +dependencies = [ + "anyhow", + "once_cell", + "prost 0.12.6", + "prost-types 0.12.6", + "rand 0.8.5", + "regex", + "reqwest 0.12.15", + "serde", + "serde_json", + "snafu 0.8.5", + "thiserror 1.0.69", + "tokio", + "tonic 0.11.0", + "tonic-build", + "url", + "uuid 1.16.0", +] + [[package]] name = "piper" version = "0.2.4" @@ -7902,6 +7983,16 @@ dependencies = [ "unarray", ] +[[package]] +name = "prost" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "deb1435c188b76130da55f17a466d252ff7b1418b2ad3e037d127b94e3411f29" +dependencies = [ + "bytes", + "prost-derive 0.12.6", +] + [[package]] name = "prost" version = "0.13.5" @@ -7909,7 +8000,28 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" dependencies = [ "bytes", - "prost-derive", + "prost-derive 0.13.5", +] + +[[package]] +name = "prost-build" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4" +dependencies = [ + "bytes", + "heck 0.5.0", + "itertools 0.12.1", + "log", + "multimap", + "once_cell", + "petgraph 0.6.5", + "prettyplease", + "prost 0.12.6", + "prost-types 0.12.6", + "regex", + "syn 2.0.100", + "tempfile", ] [[package]] @@ -7925,13 +8037,26 @@ dependencies = [ "once_cell", "petgraph 0.7.1", "prettyplease", - "prost", - "prost-types", + "prost 0.13.5", + "prost-types 0.13.5", "regex", "syn 2.0.100", "tempfile", ] +[[package]] +name = "prost-derive" +version = "0.12.6" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1" +dependencies = [ + "anyhow", + "itertools 0.12.1", + "proc-macro2", + "quote", + "syn 2.0.100", +] + [[package]] name = "prost-derive" version = "0.13.5" @@ -7945,13 +8070,22 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "prost-types" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9091c90b0a32608e984ff2fa4091273cbdd755d54935c51d520887f4a1dbd5b0" +dependencies = [ + "prost 0.12.6", +] + [[package]] name = "prost-types" version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16" dependencies = [ - "prost", + "prost 0.13.5", ] [[package]] @@ -7999,15 +8133,15 @@ dependencies = [ "derive_builder", "futures", "futures-util", - "prost", - "prost-types", + "prost 0.13.5", + "prost-types 0.13.5", "reqwest 0.12.15", "semver", "serde", "serde_json", "thiserror 1.0.69", "tokio", - "tonic", + "tonic 0.12.3", ] [[package]] @@ -8787,6 +8921,22 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "rig-pinecone" +version = "0.1.0" +dependencies = [ + "anyhow", + "httpmock", + "pinecone-sdk", + "prost-types 0.12.6", + "rig-core 0.11.0", + "serde", + "serde_json", + "testcontainers", + "tokio", + "uuid 1.16.0", +] + [[package]] name = "rig-postgres" version = "0.1.6" @@ -9156,6 +9306,20 @@ dependencies = [ "sct", ] +[[package]] +name = "rustls" +version = "0.22.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" +dependencies = [ + "log", + "ring 0.17.14", + "rustls-pki-types", + "rustls-webpki 0.102.8", + "subtle", + "zeroize", +] + [[package]] name = "rustls" version = "0.23.25" @@ -9246,6 +9410,17 @@ dependencies = [ "untrusted 0.9.0", ] +[[package]] +name = "rustls-webpki" 
+version = "0.102.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" +dependencies = [ + "ring 0.17.14", + "rustls-pki-types", + "untrusted 0.9.0", +] + [[package]] name = "rustls-webpki" version = "0.103.1" @@ -11048,6 +11223,16 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "tokio-io-timeout" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30b74022ada614a1b4834de765f9bb43877f910cc8ce4be40e89042c9223a8bf" +dependencies = [ + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-macros" version = "2.5.0" @@ -11090,6 +11275,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "775e0c0f0adb3a2f22a00c4745d728b479985fc15ee7ca6a2608388c5569860f" +dependencies = [ + "rustls 0.22.4", + "rustls-pki-types", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.26.2" @@ -11230,6 +11426,37 @@ dependencies = [ "winnow", ] +[[package]] +name = "tonic" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76c4eb7a4e9ef9d4763600161f12f5070b92a578e1b634db88a6887844c91a13" +dependencies = [ + "async-stream", + "async-trait", + "axum 0.6.20", + "base64 0.21.7", + "bytes", + "h2 0.3.26", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.32", + "hyper-timeout 0.4.1", + "percent-encoding", + "pin-project", + "prost 0.12.6", + "rustls-native-certs 0.7.3", + "rustls-pemfile 2.2.0", + "rustls-pki-types", + "tokio", + "tokio-rustls 0.25.0", + "tokio-stream", + "tower 0.4.13", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "tonic" version = "0.12.3" @@ -11238,7 +11465,7 @@ checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52" dependencies = [ "async-stream", "async-trait", - "axum", + "axum 0.7.9", 
"base64 0.22.1", "bytes", "flate2", @@ -11247,11 +11474,11 @@ dependencies = [ "http-body 1.0.1", "http-body-util", "hyper 1.6.0", - "hyper-timeout", + "hyper-timeout 0.5.2", "hyper-util", "percent-encoding", "pin-project", - "prost", + "prost 0.13.5", "rustls-native-certs 0.8.1", "rustls-pemfile 2.2.0", "socket2", @@ -11264,6 +11491,19 @@ dependencies = [ "tracing", ] +[[package]] +name = "tonic-build" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4ef6dd70a610078cb4e338a0f79d06bc759ff1b22d2120c2ff02ae264ba9c2" +dependencies = [ + "prettyplease", + "proc-macro2", + "prost-build 0.12.6", + "quote", + "syn 2.0.100", +] + [[package]] name = "tower" version = "0.4.13" diff --git a/Cargo.toml b/Cargo.toml index a110254..c193426 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,5 +12,5 @@ members = [ "rig-surrealdb", "rig-eternalai", "rig-fastembed", - "rig-bedrock", + "rig-bedrock", "rig-pinecone", ] diff --git a/rig-pinecone/Cargo.toml b/rig-pinecone/Cargo.toml new file mode 100644 index 0000000..741b33b --- /dev/null +++ b/rig-pinecone/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "rig-pinecone" +version = "0.1.0" +edition = "2021" +license = "MIT" +readme = "README.md" +description = "Rig vector store index integration for Pinecone. 
https://www.pinecone.io/" +repository = "https://github.com/0xPlaygrounds/rig" + +[dependencies] +rig-core = { path = "../rig-core", version = "0.11.0" } +serde_json = "1.0.128" +serde = "1.0.210" +pinecone-sdk = "0.1.2" +uuid = { version = "1.13.1", features = ["v4"] } +prost-types = "0.12" + +[dev-dependencies] +tokio = { version = "1.40.0", features = ["rt-multi-thread"] } +anyhow = "1.0.89" +testcontainers = "0.23.1" +httpmock = "0.7.0" diff --git a/rig-pinecone/src/lib.rs b/rig-pinecone/src/lib.rs new file mode 100644 index 0000000..95ec75d --- /dev/null +++ b/rig-pinecone/src/lib.rs @@ -0,0 +1,201 @@ +use std::collections::BTreeMap; + +use pinecone_sdk::models::{Kind, Metadata, Namespace, QueryResponse, Value, Vector}; +use pinecone_sdk::pinecone::data::Index; +use pinecone_sdk::pinecone::PineconeClient; +use prost_types::ListValue; +use rig::embeddings::EmbeddingModel; +use rig::vector_store::{VectorStoreError, VectorStoreIndex}; +use rig::{embeddings::Embedding, Embed, OneOrMany}; +use serde::Serialize; +use serde_json::Value as JsonValue; + +pub struct PineconeVectorStore { + model: M, + client: PineconeClient, + index_name: String, + namespace: Namespace, +} + +impl PineconeVectorStore +where + M: EmbeddingModel, +{ + pub fn new(client: PineconeClient, index_name: S, model: M, namespace: N) -> Self + where + S: Into, + N: Into, + { + let index_name: String = index_name.into(); + let namespace: Namespace = namespace.into(); + Self { + client, + model, + index_name, + namespace, + } + } + + pub fn update_index_name(&mut self, index_name: &str) { + self.index_name = index_name.to_string(); + } + + pub fn namespace(&self) -> &Namespace { + &self.namespace + } + + pub fn update_namespace(&mut self, namespace: Namespace) { + self.namespace = namespace; + } + + pub async fn insert_documents( + &mut self, + documents: Vec<(Doc, OneOrMany)>, + namespace: &Namespace, + ) -> Result<(), VectorStoreError> { + let vectors: Vec = documents + .into_iter() + .map(|(doc, 
embedding)| { + let metadata = { + let json_value: JsonValue = serde_json::to_value(&doc).unwrap(); + json_to_metadata(&json_value) + }; + + let values = embedding.first().vec.iter().map(|&x| x as f32).collect(); + + Vector { + id: uuid::Uuid::new_v4().to_string(), + values, + sparse_values: None, + metadata: Some(metadata), + } + }) + .collect(); + + let mut idx = self + .client + .index(&self.index_name) + .await + .map_err(|x| VectorStoreError::DatastoreError(x.into()))?; + + idx.upsert(&vectors, namespace) + .await + .map_err(|x| VectorStoreError::DatastoreError(x.into()))?; + + Ok(()) + } + + /// Embed query based on `PineconeVectorStore` model and modify the vector in the required format. + pub async fn generate_query_vector(&self, query: &str) -> Result, VectorStoreError> { + let embedding = self.model.embed_text(query).await?; + Ok(embedding.vec.iter().map(|&x| x as f32).collect()) + } +} + +impl VectorStoreIndex for PineconeVectorStore +where + M: EmbeddingModel, +{ + async fn top_n serde::Deserialize<'a> + Send>( + &self, + query: &str, + n: usize, + ) -> Result, VectorStoreError> { + let vector = self.generate_query_vector(query).await?; + let mut index = self + .client + .index(&self.index_name) + .await + .map_err(|x| VectorStoreError::DatastoreError(x.into()))?; + + let res: QueryResponse = index + .query_by_value( + vector, + None, + n as u32, + self.namespace(), + None, + Some(true), + Some(true), + ) + .await + .map_err(|x| VectorStoreError::DatastoreError(x.into()))?; + + todo!() + } + + async fn top_n_ids( + &self, + query: &str, + n: usize, + ) -> Result, VectorStoreError> { + todo!() + } +} + +pub fn json_to_metadata(json: &JsonValue) -> Metadata { + match json { + JsonValue::Object(map) => { + let fields = map + .iter() + .map(|(k, v)| (k.clone(), json_to_kind_value(v))) + .collect(); + Metadata { fields } + } + _ => { + // Not a JSON object — return empty metadata or panic based on your needs + Metadata { + fields: BTreeMap::new(), + } + } + } 
+}
+
+/// Converts a single JSON value into the metadata `Value` representation.
+///
+/// Arrays and objects are converted recursively. Every JSON number is stored
+/// as an `f64` (`Kind::NumberValue`), so integers that do not fit exactly in
+/// an `f64` lose precision in the conversion.
+fn json_to_kind_value(json: &JsonValue) -> Value {
+    let kind = match json {
+        JsonValue::Null => Some(Kind::NullValue(0)),
+        JsonValue::Bool(b) => Some(Kind::BoolValue(*b)),
+        // Try the direct `f64` view first, then fall back to the integer
+        // views; if none of them yields a value, `kind` is left as `None`.
+        JsonValue::Number(n) => n
+            .as_f64()
+            .map(Kind::NumberValue)
+            .or_else(|| n.as_i64().map(|i| Kind::NumberValue(i as f64)))
+            .or_else(|| n.as_u64().map(|u| Kind::NumberValue(u as f64))),
+        JsonValue::String(s) => Some(Kind::StringValue(s.clone())),
+        JsonValue::Array(arr) => Some(Kind::ListValue(ListValue {
+            values: arr.iter().map(json_to_kind_value).collect(),
+        })),
+        JsonValue::Object(map) => Some(Kind::StructValue(Metadata {
+            fields: map
+                .iter()
+                .map(|(k, v)| (k.clone(), json_to_kind_value(v)))
+                .collect(),
+        })),
+    };
+
+    Value { kind }
+}
+
+/// Converts `Metadata` into a JSON object.
+///
+/// Each metadata field becomes one key of the returned
+/// `serde_json::Value::Object`; nested values are converted recursively.
+pub fn metadata_to_json_value(metadata: &Metadata) -> serde_json::Value {
+    let mut map = serde_json::Map::new();
+    for (k, v) in &metadata.fields {
+        map.insert(k.clone(), convert_value_to_json(v));
+    }
+    serde_json::Value::Object(map)
+}
+
+/// Converts a single metadata `Value` into a `serde_json::Value`.
+///
+/// A value with no `kind` set is mapped to JSON `null`. Numbers that JSON
+/// cannot represent (NaN and the infinities) are also mapped to `null`
+/// rather than panicking.
+fn convert_value_to_json(value: &Value) -> serde_json::Value {
+    match &value.kind {
+        Some(Kind::NullValue(_)) => serde_json::Value::Null,
+        Some(Kind::BoolValue(b)) => serde_json::Value::Bool(*b),
+        // `Number::from_f64` returns `None` for non-finite floats; degrade to
+        // `null` so one bad metadata field cannot abort the whole conversion.
+        Some(Kind::NumberValue(n)) => serde_json::Number::from_f64(*n)
+            .map_or(serde_json::Value::Null, serde_json::Value::Number),
+        Some(Kind::StringValue(s)) => serde_json::Value::String(s.clone()),
+        Some(Kind::ListValue(list)) => {
+            let arr = list.values.iter().map(convert_value_to_json).collect();
+            serde_json::Value::Array(arr)
+        }
+        Some(Kind::StructValue(struct_val)) => metadata_to_json_value(struct_val),
+        None => serde_json::Value::Null,
+    }
+}