Merge branch 'main' into feat/model-provider/16-add-gemini-completion-embedding-models

This commit is contained in:
Mathieu 2024-11-01 12:49:46 +01:00
parent 9ae5e33b75
commit caee495dbe
32 changed files with 1709 additions and 524 deletions

View File

@ -5,9 +5,6 @@ on:
pull_request:
branches:
- main
push:
branches:
- main
workflow_call:
env:
@ -55,6 +52,8 @@ jobs:
- name: Run clippy action
uses: clechasseur/rs-clippy-check@v3
with:
args: --all-features
test:
name: stable / test
@ -79,4 +78,25 @@ jobs:
uses: actions-rs/cargo@v1
with:
command: nextest
args: run --all-features
args: run --all-features
doc:
name: stable / doc
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Rust stable
uses: actions-rust-lang/setup-rust-toolchain@v1
with:
components: rust-docs
# Required to compile rig-lancedb
- name: Install Protoc
uses: arduino/setup-protoc@v3
- name: Run cargo doc
run: cargo doc --no-deps --all-features
env:
RUSTDOCFLAGS: -D warnings

921
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,10 @@
<p align="center">
<img src="img/rig_logo.svg" alt="Rig Logo" style="width: 75%; height: 75%;"><br>
<picture>
<source media="(prefers-color-scheme: dark)" srcset="img/rig_logo_dark.svg">
<source media="(prefers-color-scheme: light)" srcset="img/rig_logo.svg">
<img src="img/rig_logo.svg" style="width: 40%; height: 40%;" alt="Rig logo">
</picture>
<br>
<a href="https://crates.io/crates/rig-core"><img src="https://img.shields.io/crates/v/rig-core.svg" /></a>
&nbsp;
<a href="https://discord.gg/playgrounds"><img src="https://img.shields.io/discord/511303648119226382?color=%236d82cc&label=Discord&logo=discord&logoColor=white" /></a>
@ -12,7 +17,7 @@
</p>
&nbsp;
> [!WARNING]
> [!WARNING]
> Here be dragons! Rig is **alpha** software and **will** contain breaking changes as it evolves. We'll annotate them and highlight migration paths as we encounter them.
@ -70,7 +75,10 @@ or just `full` to enable all features (`cargo add tokio --features macros,rt-mul
Rig supports the following LLM providers natively:
- OpenAI
- Cohere
- Anthropic
- Perplexity
- Google Gemini
Additionally, Rig currently has the following integration sub-libraries:
- MongoDB vector store: `rig-mongodb`
- LanceDB vector store: `rig-lancedb`

View File

@ -3,8 +3,8 @@
<defs>
<style>
.cls-1 {
fill: white;
stroke-width: 0px;
fill: black;
stroke-width: 10px;
}
</style>
</defs>
@ -17,4 +17,4 @@
</g>
<path class="cls-1"
d="M1821.1,424.7c0-6-1.4-11.6-3.8-16.6-5.5-11.8-16.6-19.9-29.5-19.9h-130.7c-.7,0-1.5.1-2.1.2-.7,0-1.4-.2-2.1-.2h-138.6c-11,0-19.9,8.9-19.9,19.9v68.4c0,11,8.9,19.9,19.9,19.9h101.5c-7.3,13.2-16.5,25.7-27.7,37-47,47.7-115.2,61.7-174.8,41.9-23.9-8-46.3-21.3-65.5-40.2-63.3-62.4-67.4-162.3-11.5-229.5,3-3.7,6.3-7.2,9.7-10.7,65.8-66.9,173.4-67.7,240.3-1.8,7.7,7.7,14.6,15.9,20.6,24.6,3.7,5.4,9.9,8.6,16.5,8.6h99.2c14.5,0,24.4-14.8,18.7-28.2-5.5-13-11.9-25.6-19.3-37.8,0,0,0,0,0,0l43.2-43.9c8.4-8.6,8.3-22.4-.2-30.9l-63.1-62.2-13.4-13.2c-8.6-8.4-22.4-8.3-30.9.2l-12.8,13-31.2,31.7c-22.5-12.8-46.3-22.4-70.8-28.8h0s0-2.9,0-2.9l-.5-62c0-12-10-21.8-22-21.7l-107.3.8c-12,0-21.8,10-21.7,22l.5,60.9v2.7c-26.1,6.8-51.5,17.1-75.3,31.1l-34.4-33.8-9.2-9.1c-13.1-12.9-34.3-12.7-47.2.3l-8.6,8.8-31.3,31.8-19.2,19.5c-12.9,13.1-12.7,34.3.3,47.2l42.9,42.3c-11.3,19.3-20.2,39.5-26.8,60.3h0c-8.2-2.7-17.7-.9-24.2,5.7-2.7,2.7-4.5,6-5.5,9.4l-30.3.3c-18.4.1-33.3,15.2-33.1,33.6l.6,84.3c.1,18.4,15.3,33.3,33.6,33.1l55.1-.4c6.3,24,15.6,47.3,28.1,69.5h0c-3.8,6-4.5,13.2-2.4,19.7l-28.4,28.9c-12.9,13.1-12.8,34.3.3,47.2l52.4,51.6,7.7,7.5c13.1,12.9,34.3,12.7,47.2-.3l7.1-7.2,31.9-32.3s0,0,0,0c18.7,10.3,38.1,18.5,58.1,24.4,5,1.5,10,2.9,15,4.1v3.8s.5,61.6.5,61.6c0,12,10,21.7,22,21.7l107.4-.8c12,0,21.7-10,21.7-22l-.4-60.4v-6.2c25.6-7.1,50.4-17.7,73.8-31.9h0l38.6,38.1,7.9,7.7c8.6,8.4,22.4,8.3,30.9-.2l7.4-7.5,67.9-69c8.4-8.5,8.3-22.4-.2-30.9l-47.3-46.6c11.8-20.6,20.7-42.3,27.1-64.6h2.5s32.4-.3,32.4-.3c18.2-.1,33.1-16.8,33-37.1l-.3-34ZM1101.1,453.1l-.6-72.9c-.1-14.5,11.7-26.5,26.2-26.6l26.6-.3c1.1,3.5,3,6.8,5.8,9.6.9.9,1.9,1.7,3,2.4l-49.1,109.5c-7.1-4.7-11.8-12.6-11.9-21.7ZM1178,479.9h-1.3c0,0-49-.6-49-.6-1.7,0-3.4-.1-5-.5l49.2-109.6c2,.3,4.1.3,6.2.1v110.6ZM1308.4,504.7l2.7,37.4c-1.5.2-2.9.6-4.3,1.1l-62.9-108.9c3-3.9,4.7-8.5,4.9-13.2h36c1,29,8.8,57.7,23.6,83.6ZM1268.4,307.8c1.1,1.1,2.4,2.1,3.7,3l-40.6,87c-1.9-.5-3.8-.8-5.7-.8l-4.5-123.8,41.5,10.7c-2.8,8.2-.9,17.5,5.6,24ZM1279.5,315.9v96.3h-32c-1.1-3.2-2.9-6.1-5.5-8.6-.8-.8
-1.6-1.4-2.5-2l40-85.6ZM1237,440.3l62.1,107.5c-.4.3-.8.6-1.1.9-3.4,3.5-5.5,7.8-6.3,12.2l-43.5,1.9c-1.1-2.6-2.7-5.1-4.9-7.2-3.4-3.4-7.6-5.4-12-6.2l-3.8-106.3c3.3-.2,6.6-1.2,9.6-2.9ZM1288.5,377.9v-63.9c4.7-.8,9.2-2.9,12.8-6.6,3.4-3.5,5.5-7.7,6.3-12.2l18.5,4.8c-19.2,23.1-31.8,49.9-37.6,77.9ZM1206.2,244.9c-.3-.3-.5-.5-.6-.8h0s-25.6-25.2-25.6-25.2c-10.4-10.2-10.5-27-.3-37.3l16.2-16.5,1.3-1.3,50,49.3s0,0,0,0l9.6,9.4s0,0,0,0l9.8,9.7,58.9,56.9-18.8-4.8c-1.1-3.5-3-6.9-5.8-9.7-9-8.9-23.4-8.9-32.4-.1l-43.6-11.2-18.8-18.3ZM1189.3,481.2v-116.3c.9-.7,1.9-1.5,2.7-2.3,8.3-8.3,8.8-21.3,1.8-30.4,4.8-15.9,11.2-31.2,18.9-46.2l4.1,112.5c-2.8,1.2-5.4,2.9-7.7,5.2-9,9.1-8.9,23.9.2,32.9,2.6,2.6,5.7,4.3,9,5.4l3.9,107.5c-2,.4-4,1.1-5.8,2-12.3-22.4-21.3-46.1-27.1-70.4ZM1190.2,614l22.9-23.2c2.6,1.9,5.4,3.3,8.4,4,0,0,.1,0,.2,0,1.3.3,2.7.5,4,.6.6,0,1.2,0,1.7,0,1.1,0,2.2-.1,3.3-.3h0c1,0,2-.3,3.1-.7,6.6-2.1,9.7-6.3,12.7-9.5,2.6-4,3.8-8.6,3.7-13.2l41.6-1.8c.9,4.3,3,8.4,6.4,11.7,1.7,1.7,3.6,2.9,5.6,3.9l-13.7,42.4c-6.8-.7-13.9,1.4-19,6.7-2.1,2.2-3.7,4.7-4.8,7.3l-83.4-7.1c-.6-7.4,1.8-15.1,7.4-20.8ZM1248.7,707.1c-2.3-1.2-4.4-2.7-6.3-4.5l-52-51.2c-2.2-2.2-3.9-4.7-5.2-7.3l79.1,6.8c0,6.1,2.2,12.1,6.9,16.7,1.5,1.4,3.1,2.6,4.8,3.5l-12.6,38.9c-5,.5-10.2-.5-14.7-2.9ZM1313.6,669.3l-33.8,33c-1.8,1.8-3.7,3.2-5.8,4.4l10.6-32.8c7,.9,14.3-1.3,19.6-6.7,9-9.1,8.9-23.9-.2-32.9-1.6-1.6-3.4-2.8-5.3-3.8l13.8-42.5c.7,0,1.3.2,2,.2l5.4,75.5c-1.8,1.7-3.9,3.5-6.3,5.5h0ZM1386,687.7c-.6-.2-1.1-.3-1.6-.5-17.4-5.2-34.5-12.2-50.9-20.8l62.7-19.7-10.2,41ZM1398.8,636.6l-70.3,22.1-5.2-72c2.9-1.2,5.5-2.9,7.8-5.2,9-9.1,8.9-23.9-.2-32.9-3.1-3-6.8-5-10.8-5.9l-1.6-22c5.9,8.2,12.5,16,20,23.4,20.9,20.6,45.6,35.1,71.7,43.4l.8.3-12.3,49Z" />
</svg>
</svg>

Before

Width:  |  Height:  |  Size: 4.7 KiB

After

Width:  |  Height:  |  Size: 4.7 KiB

20
img/rig_logo_dark.svg Normal file
View File

@ -0,0 +1,20 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 1899 830">
<defs>
<style>
.cls-1 {
fill: white;
stroke-width: 0px;
}
</style>
</defs>
<g>
<path class="cls-1"
d="M613.5,698h-154.5c-3.2,0-6.3-1.6-8.1-4.3l-93.7-135.7c-1.8-2.6-4.7-4.2-7.9-4.3l-22.9-.6c-5.3-.1-9.5-4.4-9.6-9.7l-1.8-110.3c0-5.5,4.3-10,9.8-10l57.8-.4c.4,0,.8,0,1.2,0,6.9-.9,71.8-11.5,72.7-76.5,1.1-76.8-68.7-85.6-74.9-86.2-.3,0-.6,0-.9,0l-57.5-.7c-5.4,0-9.7-4.4-9.7-9.8l-.4-98c0-11.5,9.3-20.8,20.8-20.8h80.3c.2,0,.4,0,.6,0,2.2.1,4.4.3,6.7.4,87.9,5.7,203.3,72.7,203.3,209.4s-74.2,164.3-100.9,179.5c-5,2.8-6.4,9.3-3.2,14.1l101.1,148.6c4.5,6.5-.2,15.4-8.2,15.4Z" />
<path class="cls-1"
d="M952.7,150.3v527.8c0,10.9-8.9,19.8-19.8,19.8h-110.8c-11.5,0-20.8-9.3-20.8-20.8V150.3c0-10.9,8.9-19.8,19.8-19.8h111.8c10.9,0,19.8,8.9,19.8,19.8Z" />
<rect class="cls-1" x="98.7" y="130.6" width="151.3" height="567.4" rx="20.8" ry="20.8" />
</g>
<path class="cls-1"
d="M1821.1,424.7c0-6-1.4-11.6-3.8-16.6-5.5-11.8-16.6-19.9-29.5-19.9h-130.7c-.7,0-1.5.1-2.1.2-.7,0-1.4-.2-2.1-.2h-138.6c-11,0-19.9,8.9-19.9,19.9v68.4c0,11,8.9,19.9,19.9,19.9h101.5c-7.3,13.2-16.5,25.7-27.7,37-47,47.7-115.2,61.7-174.8,41.9-23.9-8-46.3-21.3-65.5-40.2-63.3-62.4-67.4-162.3-11.5-229.5,3-3.7,6.3-7.2,9.7-10.7,65.8-66.9,173.4-67.7,240.3-1.8,7.7,7.7,14.6,15.9,20.6,24.6,3.7,5.4,9.9,8.6,16.5,8.6h99.2c14.5,0,24.4-14.8,18.7-28.2-5.5-13-11.9-25.6-19.3-37.8,0,0,0,0,0,0l43.2-43.9c8.4-8.6,8.3-22.4-.2-30.9l-63.1-62.2-13.4-13.2c-8.6-8.4-22.4-8.3-30.9.2l-12.8,13-31.2,31.7c-22.5-12.8-46.3-22.4-70.8-28.8h0s0-2.9,0-2.9l-.5-62c0-12-10-21.8-22-21.7l-107.3.8c-12,0-21.8,10-21.7,22l.5,60.9v2.7c-26.1,6.8-51.5,17.1-75.3,31.1l-34.4-33.8-9.2-9.1c-13.1-12.9-34.3-12.7-47.2.3l-8.6,8.8-31.3,31.8-19.2,19.5c-12.9,13.1-12.7,34.3.3,47.2l42.9,42.3c-11.3,19.3-20.2,39.5-26.8,60.3h0c-8.2-2.7-17.7-.9-24.2,5.7-2.7,2.7-4.5,6-5.5,9.4l-30.3.3c-18.4.1-33.3,15.2-33.1,33.6l.6,84.3c.1,18.4,15.3,33.3,33.6,33.1l55.1-.4c6.3,24,15.6,47.3,28.1,69.5h0c-3.8,6-4.5,13.2-2.4,19.7l-28.4,28.9c-12.9,13.1-12.8,34.3.3,47.2l52.4,51.6,7.7,7.5c13.1,12.9,34.3,12.7,47.2-.3l7.1-7.2,31.9-32.3s0,0,0,0c18.7,10.3,38.1,18.5,58.1,24.4,5,1.5,10,2.9,15,4.1v3.8s.5,61.6.5,61.6c0,12,10,21.7,22,21.7l107.4-.8c12,0,21.7-10,21.7-22l-.4-60.4v-6.2c25.6-7.1,50.4-17.7,73.8-31.9h0l38.6,38.1,7.9,7.7c8.6,8.4,22.4,8.3,30.9-.2l7.4-7.5,67.9-69c8.4-8.5,8.3-22.4-.2-30.9l-47.3-46.6c11.8-20.6,20.7-42.3,27.1-64.6h2.5s32.4-.3,32.4-.3c18.2-.1,33.1-16.8,33-37.1l-.3-34ZM1101.1,453.1l-.6-72.9c-.1-14.5,11.7-26.5,26.2-26.6l26.6-.3c1.1,3.5,3,6.8,5.8,9.6.9.9,1.9,1.7,3,2.4l-49.1,109.5c-7.1-4.7-11.8-12.6-11.9-21.7ZM1178,479.9h-1.3c0,0-49-.6-49-.6-1.7,0-3.4-.1-5-.5l49.2-109.6c2,.3,4.1.3,6.2.1v110.6ZM1308.4,504.7l2.7,37.4c-1.5.2-2.9.6-4.3,1.1l-62.9-108.9c3-3.9,4.7-8.5,4.9-13.2h36c1,29,8.8,57.7,23.6,83.6ZM1268.4,307.8c1.1,1.1,2.4,2.1,3.7,3l-40.6,87c-1.9-.5-3.8-.8-5.7-.8l-4.5-123.8,41.5,10.7c-2.8,8.2-.9,17.5,5.6,24ZM1279.5,315.9v96.3h-32c-1.1-3.2-2.9-6.1-5.5-8.6-.8-.8
-1.6-1.4-2.5-2l40-85.6ZM1237,440.3l62.1,107.5c-.4.3-.8.6-1.1.9-3.4,3.5-5.5,7.8-6.3,12.2l-43.5,1.9c-1.1-2.6-2.7-5.1-4.9-7.2-3.4-3.4-7.6-5.4-12-6.2l-3.8-106.3c3.3-.2,6.6-1.2,9.6-2.9ZM1288.5,377.9v-63.9c4.7-.8,9.2-2.9,12.8-6.6,3.4-3.5,5.5-7.7,6.3-12.2l18.5,4.8c-19.2,23.1-31.8,49.9-37.6,77.9ZM1206.2,244.9c-.3-.3-.5-.5-.6-.8h0s-25.6-25.2-25.6-25.2c-10.4-10.2-10.5-27-.3-37.3l16.2-16.5,1.3-1.3,50,49.3s0,0,0,0l9.6,9.4s0,0,0,0l9.8,9.7,58.9,56.9-18.8-4.8c-1.1-3.5-3-6.9-5.8-9.7-9-8.9-23.4-8.9-32.4-.1l-43.6-11.2-18.8-18.3ZM1189.3,481.2v-116.3c.9-.7,1.9-1.5,2.7-2.3,8.3-8.3,8.8-21.3,1.8-30.4,4.8-15.9,11.2-31.2,18.9-46.2l4.1,112.5c-2.8,1.2-5.4,2.9-7.7,5.2-9,9.1-8.9,23.9.2,32.9,2.6,2.6,5.7,4.3,9,5.4l3.9,107.5c-2,.4-4,1.1-5.8,2-12.3-22.4-21.3-46.1-27.1-70.4ZM1190.2,614l22.9-23.2c2.6,1.9,5.4,3.3,8.4,4,0,0,.1,0,.2,0,1.3.3,2.7.5,4,.6.6,0,1.2,0,1.7,0,1.1,0,2.2-.1,3.3-.3h0c1,0,2-.3,3.1-.7,6.6-2.1,9.7-6.3,12.7-9.5,2.6-4,3.8-8.6,3.7-13.2l41.6-1.8c.9,4.3,3,8.4,6.4,11.7,1.7,1.7,3.6,2.9,5.6,3.9l-13.7,42.4c-6.8-.7-13.9,1.4-19,6.7-2.1,2.2-3.7,4.7-4.8,7.3l-83.4-7.1c-.6-7.4,1.8-15.1,7.4-20.8ZM1248.7,707.1c-2.3-1.2-4.4-2.7-6.3-4.5l-52-51.2c-2.2-2.2-3.9-4.7-5.2-7.3l79.1,6.8c0,6.1,2.2,12.1,6.9,16.7,1.5,1.4,3.1,2.6,4.8,3.5l-12.6,38.9c-5,.5-10.2-.5-14.7-2.9ZM1313.6,669.3l-33.8,33c-1.8,1.8-3.7,3.2-5.8,4.4l10.6-32.8c7,.9,14.3-1.3,19.6-6.7,9-9.1,8.9-23.9-.2-32.9-1.6-1.6-3.4-2.8-5.3-3.8l13.8-42.5c.7,0,1.3.2,2,.2l5.4,75.5c-1.8,1.7-3.9,3.5-6.3,5.5h0ZM1386,687.7c-.6-.2-1.1-.3-1.6-.5-17.4-5.2-34.5-12.2-50.9-20.8l62.7-19.7-10.2,41ZM1398.8,636.6l-70.3,22.1-5.2-72c2.9-1.2,5.5-2.9,7.8-5.2,9-9.1,8.9-23.9-.2-32.9-3.1-3-6.8-5-10.8-5.9l-1.6-22c5.9,8.2,12.5,16,20,23.4,20.9,20.6,45.6,35.1,71.7,43.4l.8.3-12.3,49Z" />
</svg>

After

Width:  |  Height:  |  Size: 4.7 KiB

View File

@ -7,6 +7,31 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
## [0.3.0](https://github.com/0xPlaygrounds/rig/compare/rig-core-v0.2.1...rig-core-v0.3.0) - 2024-10-24
### Added
- Generalize `EmbeddingModel::embed_documents` with `IntoIterator`
- Add `from_env` constructor to Cohere and Anthropic clients
- Small optimization to serde_json object merging
- Add better error handling for provider clients
### Fixed
- Bad Anthropic request/response handling
- *(vector-index)* In memory vector store index incorrect search
### Other
- Made internal `json_utils` module private
- Update lib docs
- Made CompletionRequest helper method private to crate
- lint + fmt
- Simplify `agent_with_tools` example
- Fix docstring links
- Add nextest test runner to CI
- Merge pull request [#42](https://github.com/0xPlaygrounds/rig/pull/42) from 0xPlaygrounds/refactor(vector-store)/update-vector-store-index-trait
## [0.2.1](https://github.com/0xPlaygrounds/rig/compare/rig-core-v0.2.0...rig-core-v0.2.1) - 2024-10-01
### Fixed

View File

@ -1,6 +1,6 @@
[package]
name = "rig-core"
version = "0.2.1"
version = "0.3.0"
edition = "2021"
license = "MIT"
readme = "README.md"
@ -23,8 +23,14 @@ futures = "0.3.29"
ordered-float = "4.2.0"
schemars = "0.8.16"
thiserror = "1.0.61"
glob = "0.3.1"
lopdf = { version = "0.34.0", optional = true }
[dev-dependencies]
anyhow = "1.0.75"
assert_fs = "1.1.2"
tokio = { version = "1.34.0", features = ["full"] }
tracing-subscriber = "0.3.18"
[features]
pdf = ["dep:lopdf"]

View File

@ -0,0 +1,38 @@
use std::env;
use rig::{
agent::AgentBuilder,
completion::Prompt,
loaders::FileLoader,
providers::openai::{self, GPT_4O},
};
/// Builds an agent whose context contains every Rust example file, then asks it a question.
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    let api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
    let openai_client = openai::Client::new(&api_key);
    let model = openai_client.completion_model(GPT_4O);

    // Load every Rust example together with its path, skipping files that fail to load
    let examples = FileLoader::with_glob("rig-core/examples/*.rs")?
        .read_with_path()
        .ignore_errors();

    // Attach each example to the agent as its own context document
    let mut builder = AgentBuilder::new(model);
    for (path, content) in examples {
        let document = format!("Rust Example {:?}:\n{}", path, content);
        builder = builder.context(&document);
    }
    let agent = builder.build();

    // Prompt the agent and print the response
    let response = agent
        .prompt("Which rust example is best suited for the operation 1 + 2")
        .await?;
    println!("{}", response);

    Ok(())
}

View File

@ -6,7 +6,6 @@ use rig::{
};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::env;
#[derive(Deserialize)]
struct OperationArgs {
@ -92,25 +91,13 @@ impl Tool for Subtract {
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Create OpenAI client
let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
let openai_client = providers::openai::Client::new(&openai_api_key);
let openai_client = providers::openai::Client::from_env();
// Create agent with a single context prompt and two tools
let gpt4_calculator_agent = openai_client
.agent("gpt-4")
.context("You are a calculator here to help the user perform arithmetic operations.")
.tool(Adder)
.tool(Subtract)
.build();
// Create OpenAI client
let cohere_api_key = env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
let cohere_client = providers::cohere::Client::new(&cohere_api_key);
// Create agent with a single context prompt and two tools
let coral_calculator_agent = cohere_client
.agent("command-r")
.preamble("You are a calculator here to help the user perform arithmetic operations.")
let calculator_agent = openai_client
.agent(providers::openai::GPT_4O)
.preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
.max_tokens(1024)
.tool(Adder)
.tool(Subtract)
.build();
@ -118,12 +105,8 @@ async fn main() -> Result<(), anyhow::Error> {
// Prompt the agent and print the response
println!("Calculate 2 - 5");
println!(
"GPT-4: {}",
gpt4_calculator_agent.prompt("Calculate 2 - 5").await?
);
println!(
"Coral: {}",
coral_calculator_agent.prompt("Calculate 2 - 5").await?
"Calculator Agent: {}",
calculator_agent.prompt("Calculate 2 - 5").await?
);
Ok(())

View File

@ -0,0 +1,14 @@
use rig::loaders::FileLoader;
/// Prints the contents of the manifest file found in the current directory, demonstrating
/// [`FileLoader`].
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    // Glob matching is case-sensitive by default, so the manifest must be spelled `Cargo.toml`;
    // the lowercase spelling matches nothing on case-sensitive filesystems.
    for result in FileLoader::with_glob("Cargo.toml")?.read() {
        match result {
            Ok(content) => println!("{}", content),
            Err(e) => eprintln!("Error reading file: {}", e),
        }
    }

    Ok(())
}

View File

@ -266,7 +266,7 @@ pub struct CompletionRequest {
}
impl CompletionRequest {
pub fn prompt_with_context(&self) -> String {
pub(crate) fn prompt_with_context(&self) -> String {
if !self.documents.is_empty() {
format!(
"<attachments>\n{}</attachments>\n\n{}",
@ -439,14 +439,14 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
}
/// Sets the max tokens for the completion request.
/// Only required for: [ Anthropic ]
/// Note: This is required if using Anthropic
pub fn max_tokens(mut self, max_tokens: u64) -> Self {
self.max_tokens = Some(max_tokens);
self
}
/// Sets the max tokens for the completion request.
/// Only required for: [ Anthropic ]
/// Note: This is required if using Anthropic
pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
self.max_tokens = max_tokens;
self

View File

@ -97,7 +97,7 @@ pub trait EmbeddingModel: Clone + Sync + Send {
/// Embed multiple documents in a single request
fn embed_documents(
&self,
documents: Vec<String>,
documents: impl IntoIterator<Item = String> + Send,
) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;
}

View File

@ -1,11 +1,19 @@
pub fn merge(a: serde_json::Value, b: serde_json::Value) -> serde_json::Value {
match (a.clone(), b) {
(serde_json::Value::Object(mut a), serde_json::Value::Object(b)) => {
b.into_iter().for_each(|(key, value)| {
a.insert(key.clone(), value.clone());
match (a, b) {
(serde_json::Value::Object(mut a_map), serde_json::Value::Object(b_map)) => {
b_map.into_iter().for_each(|(key, value)| {
a_map.insert(key, value);
});
serde_json::Value::Object(a)
serde_json::Value::Object(a_map)
}
_ => a,
(a, _) => a,
}
}
/// Merges `b` into `a` in place.
///
/// When both values are JSON objects, every entry of `b` is inserted into `a`, replacing any
/// entry that shares its key. For any other combination of values `a` is left untouched.
pub fn merge_inplace(a: &mut serde_json::Value, b: serde_json::Value) {
    if let (serde_json::Value::Object(target), serde_json::Value::Object(source)) = (a, b) {
        target.extend(source);
    }
}

View File

@ -54,24 +54,27 @@
//! Rig provides a common interface for working with vector stores and indexes. Specifically, the library
//! provides the [VectorStore](crate::vector_store::VectorStore) and [VectorStoreIndex](crate::vector_store::VectorStoreIndex)
//! traits, which can be implemented to define vector stores and indices respectively.
//! Those can then be used as the knowledgebase for a [RagAgent](crate::rag::RagAgent), or
//! Those can then be used as the knowledgebase for a RAG enabled [Agent](crate::agent::Agent), or
//! as a source of context documents in a custom architecture that use multiple LLMs or agents.
//!
//! # Integrations
//! Rig natively supports the following completion and embedding model providers:
//! - OpenAI
//! - Cohere
//! - Anthropic
//! - Perplexity
//!
//! Rig currently has the following integration companion crates:
//! - `rig-mongodb`: Vector store implementation for MongoDB
//!
//! - `rig-lancedb`: Vector store implementation for LanceDB
pub mod agent;
pub mod cli_chatbot;
pub mod completion;
pub mod embeddings;
pub mod extractor;
pub mod json_utils;
pub(crate) mod json_utils;
pub mod loaders;
pub mod providers;
pub mod tool;
pub mod vector_store;

View File

@ -0,0 +1,273 @@
use std::{fs, path::PathBuf};
use glob::glob;
use thiserror::Error;
/// Errors that can occur while locating or reading files with [FileLoader].
#[derive(Error, Debug)]
pub enum FileLoaderError {
/// A glob pattern was rejected; the payload is the text shown in the error message.
#[error("Invalid glob pattern: {0}")]
InvalidGlobPattern(String),
/// A filesystem operation failed. Converted automatically from [std::io::Error].
#[error("IO error: {0}")]
IoError(#[from] std::io::Error),
/// The glob pattern could not be parsed. Converted automatically from [glob::PatternError].
#[error("Pattern error: {0}")]
PatternError(#[from] glob::PatternError),
/// A path matched by the glob could not be yielded. Converted automatically from
/// [glob::GlobError].
#[error("Glob error: {0}")]
GlobError(#[from] glob::GlobError),
}
// ================================================================
// Implementing Readable trait for reading file contents
// ================================================================
/// Crate-internal abstraction over values that can be consumed to produce file contents.
/// Implemented in this module for [PathBuf] and for `Result<T, FileLoaderError>` where
/// `T: Readable`.
pub(crate) trait Readable {
/// Consumes `self` and returns the contents as a string.
fn read(self) -> Result<String, FileLoaderError>;
/// Consumes `self` and returns a path together with the contents as a string.
fn read_with_path(self) -> Result<(PathBuf, String), FileLoaderError>;
}
impl<'a> FileLoader<'a, PathBuf> {
pub fn read(self) -> FileLoader<'a, Result<String, FileLoaderError>> {
FileLoader {
iterator: Box::new(self.iterator.map(|res| res.read())),
}
}
pub fn read_with_path(self) -> FileLoader<'a, Result<(PathBuf, String), FileLoaderError>> {
FileLoader {
iterator: Box::new(self.iterator.map(|res| res.read_with_path())),
}
}
}
impl Readable for PathBuf {
    /// Reads the whole file at this path into a string.
    fn read(self) -> Result<String, FileLoaderError> {
        // `?` converts the `std::io::Error` into `FileLoaderError::IoError`
        let contents = fs::read_to_string(self)?;
        Ok(contents)
    }

    /// Reads the whole file at this path into a string and returns it alongside the path.
    fn read_with_path(self) -> Result<(PathBuf, String), FileLoaderError> {
        match fs::read_to_string(&self) {
            Ok(contents) => Ok((self, contents)),
            Err(err) => Err(FileLoaderError::IoError(err)),
        }
    }
}
impl<T: Readable> Readable for Result<T, FileLoaderError> {
    /// Reads the wrapped value when it is `Ok`; an existing `Err` is returned unchanged.
    fn read(self) -> Result<String, FileLoaderError> {
        self.and_then(T::read)
    }

    /// Reads the wrapped value with its path when it is `Ok`; an existing `Err` is returned
    /// unchanged.
    fn read_with_path(self) -> Result<(PathBuf, String), FileLoaderError> {
        self.and_then(T::read_with_path)
    }
}
// ================================================================
// FileLoader definitions and implementations
// ================================================================
/// [FileLoader] is a utility for loading files from the filesystem using glob patterns or directory
/// paths. It provides methods to read file contents and handle errors gracefully.
///
/// # Errors
///
/// This module defines a custom error type [FileLoaderError] which can represent various errors
/// that might occur during file loading operations, such as invalid glob patterns, IO errors, and
/// glob errors.
///
/// # Example Usage
///
/// ```rust
/// use rig::loaders::FileLoader;
///
/// fn main() -> Result<(), Box<dyn std::error::Error>> {
///     // Create a FileLoader using a glob pattern
///     let loader = FileLoader::with_glob("path/to/files/*.txt")?;
///
///     // Read file contents, ignoring any errors
///     let contents: Vec<String> = loader
///         .read()
///         .ignore_errors()
///         .into_iter()
///         .collect();
///
///     for content in contents {
///         println!("{}", content);
///     }
///
///     Ok(())
/// }
/// ```
///
/// [FileLoader] uses strict typing between the iterator methods to ensure that transitions between
/// different implementations of the loaders and its methods are handled properly by the compiler.
pub struct FileLoader<'a, T> {
// Boxed so that each adapter method can store a different concrete iterator type
iterator: Box<dyn Iterator<Item = T> + 'a>,
}
impl<'a> FileLoader<'a, Result<PathBuf, FileLoaderError>> {
/// Reads the contents of the files within the iterator returned by [FileLoader::with_glob] or
/// [FileLoader::with_dir].
///
/// # Example
/// Read files in directory "files/*.txt" and print the content for each file
///
/// ```rust
/// let content = FileLoader::with_glob(...)?.read();
/// for result in content {
/// match result {
/// Ok(content) => println!("{}", content),
/// Err(e) => eprintln!("Error reading file: {}", e),
/// }
/// }
/// ```
pub fn read(self) -> FileLoader<'a, Result<String, FileLoaderError>> {
FileLoader {
iterator: Box::new(self.iterator.map(|res| res.read())),
}
}
/// Reads the contents of the files within the iterator returned by [FileLoader::with_glob] or
/// [FileLoader::with_dir] and returns the path along with the content.
///
/// # Example
/// Read files in directory "files/*.txt" and print the content for cooresponding path for each
/// file.
///
/// ```rust
/// let content = FileLoader::with_glob("files/*.txt")?.read();
/// for (path, result) in content {
/// match result {
/// Ok((path, content)) => println!("{:?} {}", path, content),
/// Err(e) => eprintln!("Error reading file: {}", e),
/// }
/// }
/// ```
pub fn read_with_path(self) -> FileLoader<'a, Result<(PathBuf, String), FileLoaderError>> {
FileLoader {
iterator: Box::new(self.iterator.map(|res| res.read_with_path())),
}
}
}
impl<'a, T: 'a> FileLoader<'a, Result<T, FileLoaderError>> {
    /// Drops every `Err` item and unwraps every `Ok` item. Available on any [FileLoader] whose
    /// items are results carrying a [FileLoaderError].
    ///
    /// # Example
    /// Read the files matching "files/*.txt" and skip the ones that cannot be read.
    ///
    /// ```rust,ignore
    /// let contents = FileLoader::with_glob("files/*.txt")?.read().ignore_errors();
    /// for content in contents {
    ///     println!("{}", content)
    /// }
    /// ```
    pub fn ignore_errors(self) -> FileLoader<'a, T> {
        let successes = self.iterator.filter_map(Result::ok);
        FileLoader {
            iterator: Box::new(successes),
        }
    }
}
impl<'a> FileLoader<'a, Result<PathBuf, FileLoaderError>> {
    /// Creates a new [FileLoader] using a glob pattern to match files.
    ///
    /// The returned loader does not borrow `pattern`. Paths that the glob iterator fails to
    /// yield appear as [FileLoaderError::GlobError] items.
    ///
    /// # Errors
    /// Returns [FileLoaderError::PatternError] if `pattern` is not a valid glob pattern.
    ///
    /// # Example
    /// Create a [FileLoader] for all `.txt` files that match the glob "files/*.txt".
    ///
    /// ```rust,ignore
    /// let loader = FileLoader::with_glob("files/*.txt")?;
    /// ```
    pub fn with_glob(pattern: &str) -> Result<Self, FileLoaderError> {
        let paths = glob(pattern)?;
        Ok(FileLoader {
            iterator: Box::new(paths.map(|path| path.map_err(FileLoaderError::GlobError))),
        })
    }

    /// Creates a new [FileLoader] on all files within a directory.
    ///
    /// Only entries for which `Path::is_file` is true are yielded, so subdirectories are
    /// skipped. A directory entry that cannot be read is yielded as a
    /// [FileLoaderError::IoError] item instead of being silently dropped.
    ///
    /// # Errors
    /// Returns [FileLoaderError::IoError] if `directory` cannot be opened for listing.
    ///
    /// # Example
    /// Create a [FileLoader] for all files that are in the directory "files" (ignores subdirectories).
    ///
    /// ```rust,ignore
    /// let loader = FileLoader::with_dir("files")?;
    /// ```
    pub fn with_dir(directory: &str) -> Result<Self, FileLoaderError> {
        let entries = fs::read_dir(directory)?;
        Ok(FileLoader {
            iterator: Box::new(entries.filter_map(|entry| match entry {
                Ok(entry) => {
                    let path = entry.path();
                    path.is_file().then_some(Ok(path))
                }
                // Surface the failure so callers can report it or drop it via `ignore_errors`
                Err(err) => Some(Err(FileLoaderError::IoError(err))),
            })),
        })
    }
}
// ================================================================
// Iterators for FileLoader
// ================================================================
/// Owning iterator over the items of a [FileLoader], produced by its `IntoIterator`
/// implementation.
pub struct IntoIter<'a, T> {
// The boxed iterator taken over from the loader
iterator: Box<dyn Iterator<Item = T> + 'a>,
}
impl<'a, T> IntoIterator for FileLoader<'a, T> {
type Item = T;
type IntoIter = IntoIter<'a, T>;
fn into_iter(self) -> Self::IntoIter {
IntoIter {
iterator: self.iterator,
}
}
}
impl<T> Iterator for IntoIter<'_, T> {
    type Item = T;

    /// Returns the next item of the wrapped iterator.
    fn next(&mut self) -> Option<Self::Item> {
        let Self { iterator } = self;
        iterator.next()
    }
}
#[cfg(test)]
mod tests {
    use assert_fs::prelude::{FileWriteStr, PathChild};

    use super::FileLoader;

    /// Writes two text files into a temporary directory and checks that a glob-based loader
    /// returns the contents of both.
    #[test]
    fn test_file_loader() {
        let temp = assert_fs::TempDir::new().expect("Failed to create temp dir");

        // `write_str` creates the file, so no separate `touch` is needed
        temp.child("foo.txt")
            .write_str("foo")
            .expect("Failed to write to foo");
        temp.child("bar.txt")
            .write_str("bar")
            .expect("Failed to write to bar");

        // Kept in a named binding so the pattern outlives the loader built from it
        let pattern = temp.path().join("*.txt").to_string_lossy().into_owned();

        let loader = FileLoader::with_glob(&pattern).unwrap();
        let mut actual = loader
            .ignore_errors()
            .read()
            .ignore_errors()
            .into_iter()
            .collect::<Vec<_>>();

        // Glob order is not part of the contract under test, so compare sorted contents
        let mut expected = vec!["foo".to_string(), "bar".to_string()];
        actual.sort();
        expected.sort();

        // `assert_eq!` prints both sides on failure, unlike `assert!(expected == actual)`
        assert_eq!(actual, expected);
    }
}

View File

@ -0,0 +1,9 @@
pub mod file;
pub use file::FileLoader;
#[cfg(feature = "pdf")]
pub mod pdf;
#[cfg(feature = "pdf")]
pub use pdf::PdfFileLoader;

456
rig-core/src/loaders/pdf.rs Normal file
View File

@ -0,0 +1,456 @@
use std::{fs, path::PathBuf};
use glob::glob;
use lopdf::{Document, Error as LopdfError};
use thiserror::Error;
use super::file::FileLoaderError;
#[derive(Error, Debug)]
pub enum PdfLoaderError {
#[error("{0}")]
FileLoaderError(#[from] FileLoaderError),
#[error("UTF-8 conversion error: {0}")]
FromUtf8Error(#[from] std::string::FromUtf8Error),
#[error("IO error: {0}")]
PdfError(#[from] LopdfError),
}
// ================================================================
// Implementing Loadable trait for loading pdfs
// ================================================================
/// Crate-internal abstraction over values that can be consumed to produce a loaded PDF
/// [Document]. Implemented in this module for [PathBuf] and for `Result<T, PdfLoaderError>`
/// where `T: Loadable`.
pub(crate) trait Loadable {
/// Consumes `self` and returns the loaded document.
fn load(self) -> Result<Document, PdfLoaderError>;
/// Consumes `self` and returns a path together with the loaded document.
fn load_with_path(self) -> Result<(PathBuf, Document), PdfLoaderError>;
}
impl Loadable for PathBuf {
    /// Loads the PDF document stored at this path.
    fn load(self) -> Result<Document, PdfLoaderError> {
        // `?` converts the lopdf error into `PdfLoaderError::PdfError`
        let document = Document::load(self)?;
        Ok(document)
    }

    /// Loads the PDF document stored at this path and returns it alongside the path.
    fn load_with_path(self) -> Result<(PathBuf, Document), PdfLoaderError> {
        match Document::load(&self) {
            Ok(document) => Ok((self, document)),
            Err(err) => Err(PdfLoaderError::PdfError(err)),
        }
    }
}
impl<T: Loadable> Loadable for Result<T, PdfLoaderError> {
    /// Loads the wrapped value when it is `Ok`; an existing `Err` is returned unchanged.
    fn load(self) -> Result<Document, PdfLoaderError> {
        self.and_then(T::load)
    }

    /// Loads the wrapped value with its path when it is `Ok`; an existing `Err` is returned
    /// unchanged.
    fn load_with_path(self) -> Result<(PathBuf, Document), PdfLoaderError> {
        self.and_then(T::load_with_path)
    }
}
// ================================================================
// PdfFileLoader definitions and implementations
// ================================================================
/// [PdfFileLoader] is a utility for loading pdf files from the filesystem using glob patterns or
/// directory paths. It provides methods to read file contents and handle errors gracefully.
///
/// # Errors
///
/// This module defines a custom error type [PdfLoaderError] which can represent various errors
/// that might occur during file loading operations, such as any [FileLoaderError] alongside
/// specific PDF-related errors.
///
/// # Example Usage
///
/// ```rust,ignore
/// use rig::loaders::PdfFileLoader;
///
/// fn main() -> Result<(), Box<dyn std::error::Error>> {
///     // Create a PdfFileLoader using a glob pattern
///     let loader = PdfFileLoader::with_glob("tests/data/*.pdf")?;
///
///     // Load the pdf documents and extract their text page by page, ignoring any errors
///     let contents: Vec<String> = loader
///         .load()
///         .ignore_errors()
///         .by_page()
///         .ignore_errors()
///         .into_iter()
///         .collect();
///
///     for content in contents {
///         println!("{}", content);
///     }
///
///     Ok(())
/// }
/// ```
///
/// [PdfFileLoader] uses strict typing between the iterator methods to ensure that transitions
/// between different implementations of the loaders and its methods are handled properly by
/// the compiler.
pub struct PdfFileLoader<'a, T> {
// Boxed so that each adapter method can store a different concrete iterator type
iterator: Box<dyn Iterator<Item = T> + 'a>,
}
impl<'a> PdfFileLoader<'a, Result<PathBuf, PdfLoaderError>> {
/// Loads the contents of the pdfs within the iterator returned by [PdfFileLoader::with_glob]
/// or [PdfFileLoader::with_dir]. Loaded PDF documents are raw PDF instances that can be
/// further processed (by page, etc).
///
/// # Example
/// Load pdfs in directory "tests/data/*.pdf" and return the loaded documents
///
/// ```rust
/// let content = PdfFileLoader::with_glob("tests/data/*.pdf")?.load().into_iter();
/// for result in content {
/// match result {
/// Ok((path, doc)) => println!("{:?} {}", path, doc),
/// Err(e) => eprintln!("Error reading pdf: {}", e),
/// }
/// }
/// ```
pub fn load(self) -> PdfFileLoader<'a, Result<Document, PdfLoaderError>> {
PdfFileLoader {
iterator: Box::new(self.iterator.map(|res| res.load())),
}
}
/// Loads the contents of the pdfs within the iterator returned by [PdfFileLoader::with_glob]
/// or [PdfFileLoader::with_dir]. Loaded PDF documents are raw PDF instances with their path
/// that can be further processed.
///
/// # Example
/// Load pdfs in directory "tests/data/*.pdf" and return the loaded documents
///
/// ```rust
/// let content = PdfFileLoader::with_glob("tests/data/*.pdf")?.load_with_path().into_iter();
/// for result in content {
/// match result {
/// Ok((path, doc)) => println!("{:?} {}", path, doc),
/// Err(e) => eprintln!("Error reading pdf: {}", e),
/// }
/// }
/// ```
pub fn load_with_path(self) -> PdfFileLoader<'a, Result<(PathBuf, Document), PdfLoaderError>> {
PdfFileLoader {
iterator: Box::new(self.iterator.map(|res| res.load_with_path())),
}
}
}
impl<'a> PdfFileLoader<'a, Result<PathBuf, PdfLoaderError>> {
/// Directly reads the contents of the pdfs within the iterator returned by
/// [PdfFileLoader::with_glob] or [PdfFileLoader::with_dir].
///
/// # Example
/// Read pdfs in directory "tests/data/*.pdf" and return the contents of the documents.
///
/// ```rust
/// let content = PdfFileLoader::with_glob("tests/data/*.pdf")?.read_with_path().into_iter();
/// for result in content {
/// match result {
/// Ok((path, content)) => println!("{}", content),
/// Err(e) => eprintln!("Error reading pdf: {}", e),
/// }
/// }
/// ```
pub fn read(self) -> PdfFileLoader<'a, Result<String, PdfLoaderError>> {
PdfFileLoader {
iterator: Box::new(self.iterator.map(|res| {
let doc = res.load()?;
Ok(doc
.page_iter()
.enumerate()
.map(|(page_no, _)| {
doc.extract_text(&[page_no as u32 + 1])
.map_err(PdfLoaderError::PdfError)
})
.collect::<Result<Vec<String>, PdfLoaderError>>()?
.into_iter()
.collect::<String>())
})),
}
}
/// Directly reads the contents of the pdfs within the iterator returned by
/// [PdfFileLoader::with_glob] or [PdfFileLoader::with_dir] and returns the path along with
/// the content.
///
/// # Example
/// Read pdfs in directory "tests/data/*.pdf" and return the content and paths of the documents.
///
/// ```rust
/// let content = PdfFileLoader::with_glob("tests/data/*.pdf")?.read_with_path().into_iter();
/// for result in content {
/// match result {
/// Ok((path, content)) => println!("{:?} {}", path, content),
/// Err(e) => eprintln!("Error reading pdf: {}", e),
/// }
/// }
/// ```
pub fn read_with_path(self) -> PdfFileLoader<'a, Result<(PathBuf, String), PdfLoaderError>> {
PdfFileLoader {
iterator: Box::new(self.iterator.map(|res| {
let (path, doc) = res.load_with_path()?;
println!(
"Loaded {:?} PDF: {:?}",
path,
doc.page_iter().collect::<Vec<_>>()
);
let content = doc
.page_iter()
.enumerate()
.map(|(page_no, _)| {
doc.extract_text(&[page_no as u32 + 1])
.map_err(PdfLoaderError::PdfError)
})
.collect::<Result<Vec<String>, PdfLoaderError>>()?
.into_iter()
.collect::<String>();
Ok((path, content))
})),
}
}
}
impl<'a> PdfFileLoader<'a, Document> {
    /// Chunks the loaded documents by page, flattened into a single stream of page texts.
    ///
    /// Every document contributes one item per page: `Ok(text)` when the page text could be
    /// extracted, or `Err(PdfLoaderError::PdfError(..))` when it could not. The items of all
    /// documents follow each other without any marker of which document a page came from.
    ///
    /// # Example
    /// Load pdfs in directory "tests/data/*.pdf" and chunk all documents into their pages.
    ///
    /// ```rust
    /// let content = PdfFileLoader::with_glob("tests/data/*.pdf")?
    ///     .load()
    ///     .ignore_errors()
    ///     .by_page()
    ///     .into_iter();
    ///
    /// for result in content {
    ///     match result {
    ///         Ok(page) => println!("{}", page),
    ///         Err(e) => eprintln!("Error reading page: {}", e),
    ///     }
    /// }
    /// ```
    pub fn by_page(self) -> PdfFileLoader<'a, Result<String, PdfLoaderError>> {
        PdfFileLoader {
            iterator: Box::new(self.iterator.flat_map(|doc| {
                // The closure owns `doc`, so the page texts are collected eagerly here: a lazy
                // iterator borrowing `doc` could not be returned from this closure.
                // Pages are requested by number, counting up from 1.
                (1u32..)
                    .zip(doc.page_iter())
                    .map(|(page_no, _)| {
                        doc.extract_text(&[page_no])
                            .map_err(PdfLoaderError::PdfError)
                    })
                    .collect::<Vec<_>>()
            })),
        }
    }
}
/// Per-document result of chunking by page: the document's path, paired with one
/// `(page index, extraction result)` entry per page.
type ByPage = (PathBuf, Vec<(usize, Result<String, PdfLoaderError>)>);
impl<'a> PdfFileLoader<'a, (PathBuf, Document)> {
    /// Chunks the loaded documents by page, keeping the pages grouped per document.
    ///
    /// Each item is a `(path, pages)` tuple in which `pages` holds one `(page_no, result)` entry
    /// per page of the document. `page_no` is the zero-based index of the page, and `result` is
    /// the extracted text or the [PdfLoaderError::PdfError] raised while extracting it.
    ///
    /// # Example
    /// Read pdfs in directory "tests/data/*.pdf" and chunk every document, by path, into its pages.
    ///
    /// ```rust
    /// let content = PdfFileLoader::with_glob("tests/data/*.pdf")?
    ///     .load_with_path()
    ///     .ignore_errors()
    ///     .by_page()
    ///     .into_iter();
    ///
    /// for (path, pages) in content {
    ///     for (page_no, page) in pages {
    ///         match page {
    ///             Ok(content) => println!("{:?} page {}: {}", path, page_no, content),
    ///             Err(e) => eprintln!("Error reading page {} of {:?}: {}", page_no, path, e),
    ///         }
    ///     }
    /// }
    /// ```
    pub fn by_page(self) -> PdfFileLoader<'a, ByPage> {
        PdfFileLoader {
            iterator: Box::new(self.iterator.map(|(path, doc)| {
                let pages = doc
                    .page_iter()
                    .enumerate()
                    .map(|(page_no, _)| {
                        // Reported indices start at 0; extraction is requested with index + 1.
                        let text = doc
                            .extract_text(&[page_no as u32 + 1])
                            .map_err(PdfLoaderError::PdfError);
                        (page_no, text)
                    })
                    .collect::<Vec<_>>();
                (path, pages)
            })),
        }
    }
}
impl<'a> PdfFileLoader<'a, ByPage> {
    /// Drops the pages that failed to extract, keeping only the successfully read ones.
    ///
    /// Every document stays in the iterator as a `(path, pages)` tuple, even when none of its
    /// pages could be read; only the failed `(page_no, Err(..))` entries are removed from
    /// `pages`. The page numbers of the remaining entries are left untouched.
    ///
    /// # Example
    /// Read pdfs in directory "tests/data/*.pdf" page by page and ignore unreadable pages.
    ///
    /// ```rust
    /// let content = PdfFileLoader::with_glob("tests/data/*.pdf")?
    ///     .load_with_path()
    ///     .ignore_errors()
    ///     .by_page()
    ///     .ignore_errors()
    ///     .into_iter();
    ///
    /// for (path, pages) in content {
    ///     for (page_no, content) in pages {
    ///         println!("{:?} page {}: {}", path, page_no, content);
    ///     }
    /// }
    /// ```
    pub fn ignore_errors(self) -> PdfFileLoader<'a, (PathBuf, Vec<(usize, String)>)> {
        PdfFileLoader {
            iterator: Box::new(self.iterator.map(|(path, pages)| {
                let readable = pages
                    .into_iter()
                    .filter_map(|(page_no, res)| res.ok().map(|content| (page_no, content)))
                    .collect::<Vec<_>>();
                (path, readable)
            })),
        }
    }
}
impl<'a, T: 'a> PdfFileLoader<'a, Result<T, PdfLoaderError>> {
    /// Ignores errors in the iterator, returning only successful results. This can be used on any
    /// [PdfFileLoader] state of iterator whose items are results.
    ///
    /// `Err` items are silently discarded; `Ok` items are unwrapped and passed through.
    ///
    /// # Example
    /// Read pdfs in directory "tests/data/*.pdf" and ignore errors from unreadable files.
    ///
    /// ```rust
    /// let content = PdfFileLoader::with_glob("tests/data/*.pdf")?.read().ignore_errors().into_iter();
    /// for text in content {
    ///     println!("{}", text)
    /// }
    /// ```
    pub fn ignore_errors(self) -> PdfFileLoader<'a, T> {
        PdfFileLoader {
            iterator: Box::new(self.iterator.filter_map(Result::ok)),
        }
    }
}
impl<'a> PdfFileLoader<'a, Result<PathBuf, FileLoaderError>> {
    /// Creates a new [PdfFileLoader] using a glob pattern to match files.
    ///
    /// The returned loader yields one item per glob match: the matched path, or the glob error
    /// wrapped in [PdfLoaderError::FileLoaderError]. It does not borrow from `pattern`.
    ///
    /// # Errors
    /// Returns an error if `pattern` is not a valid glob pattern.
    ///
    /// # Example
    /// Create a [PdfFileLoader] for all `.pdf` files that match the glob "tests/data/*.pdf".
    ///
    /// ```rust
    /// let loader = PdfFileLoader::with_glob("tests/data/*.pdf")?;
    /// ```
    pub fn with_glob(
        pattern: &str,
    ) -> Result<PdfFileLoader<'a, Result<PathBuf, PdfLoaderError>>, PdfLoaderError> {
        let paths = glob(pattern).map_err(FileLoaderError::PatternError)?;
        Ok(PdfFileLoader {
            iterator: Box::new(paths.into_iter().map(|path| {
                path.map_err(FileLoaderError::GlobError)
                    .map_err(PdfLoaderError::FileLoaderError)
            })),
        })
    }

    /// Creates a new [PdfFileLoader] on all files within a directory.
    ///
    /// Every directory entry is yielded, whatever its extension; entries that cannot be read
    /// are yielded as errors. The returned loader does not borrow from `directory`.
    ///
    /// # Errors
    /// Returns an error if the directory itself cannot be read.
    ///
    /// # Example
    /// Create a [PdfFileLoader] for all files that are in the directory "files".
    ///
    /// ```rust
    /// let loader = PdfFileLoader::with_dir("files")?;
    /// ```
    pub fn with_dir(
        directory: &str,
    ) -> Result<PdfFileLoader<'a, Result<PathBuf, PdfLoaderError>>, PdfLoaderError> {
        let entries = fs::read_dir(directory).map_err(FileLoaderError::IoError)?;
        Ok(PdfFileLoader {
            iterator: Box::new(entries.map(|entry| -> Result<PathBuf, PdfLoaderError> {
                Ok(entry.map_err(FileLoaderError::IoError)?.path())
            })),
        })
    }
}
// ================================================================
// PdfFileLoader iterator implementations
// ================================================================
/// Owning iterator over the items of a [PdfFileLoader], produced by its `into_iter` method.
pub struct IntoIter<'a, T> {
    // The boxed iterator taken over from the loader.
    iterator: Box<dyn Iterator<Item = T> + 'a>,
}
impl<'a, T> IntoIterator for PdfFileLoader<'a, T> {
type Item = T;
type IntoIter = IntoIter<'a, T>;
fn into_iter(self) -> Self::IntoIter {
IntoIter {
iterator: self.iterator,
}
}
}
impl<'a, T> Iterator for IntoIter<'a, T> {
    type Item = T;

    /// Returns the next item of the wrapped iterator.
    fn next(&mut self) -> Option<Self::Item> {
        self.iterator.next()
    }

    /// Forwards the wrapped iterator's size hint so that consumers such as `collect` and
    /// `extend` can reserve capacity instead of seeing the default `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iterator.size_hint()
    }
}
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::PdfFileLoader;

    /// Loads every pdf under `tests/data`, chunks it by page and compares the readable pages
    /// with the known contents of the fixture files.
    #[test]
    fn test_pdf_loader() {
        let loader = PdfFileLoader::with_glob("tests/data/*.pdf").unwrap();
        let mut actual = loader
            .load_with_path()
            .ignore_errors()
            .by_page()
            .ignore_errors()
            .into_iter()
            .collect::<Vec<_>>();

        let mut expected = vec![
            (
                PathBuf::from("tests/data/dummy.pdf"),
                vec![(0, "Test\nPDF\nDocument\n".to_string())],
            ),
            (
                PathBuf::from("tests/data/pages.pdf"),
                vec![
                    (0, "Page\n1\n".to_string()),
                    (1, "Page\n2\n".to_string()),
                    (2, "Page\n3\n".to_string()),
                ],
            ),
        ];

        // Both sides are sorted so the comparison does not depend on glob match order.
        actual.sort();
        expected.sort();

        assert!(!actual.is_empty());
        // `assert_eq!` prints both sides on failure, which a bare `assert!(a == b)` does not.
        assert_eq!(expected, actual);
    }
}

View File

@ -113,6 +113,13 @@ impl Client {
}
}
/// Create a new Anthropic client from the `ANTHROPIC_API_KEY` environment variable.
///
/// # Panics
/// Panics if the environment variable cannot be read, for example because it is not set.
pub fn from_env() -> Self {
    ClientBuilder::new(&std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"))
        .build()
}
pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
self.http_client.post(url)

View File

@ -47,16 +47,14 @@ pub struct CompletionResponse {
pub enum Content {
String(String),
Text {
r#type: String,
text: String,
#[serde(rename = "type")]
content_type: String,
},
ToolUse {
r#type: String,
id: String,
name: String,
input: String,
#[serde(rename = "type")]
content_type: String,
input: serde_json::Value,
},
}
@ -73,7 +71,6 @@ pub struct ToolDefinition {
pub name: String,
pub description: Option<String>,
pub input_schema: serde_json::Value,
pub cache_control: Option<CacheControl>,
}
#[derive(Debug, Deserialize, Serialize)]
@ -94,10 +91,7 @@ impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionRe
})
}
[Content::ToolUse { name, input, .. }, ..] => Ok(completion::CompletionResponse {
choice: completion::ModelChoice::ToolCall(
name.clone(),
serde_json::from_str(input)?,
),
choice: completion::ModelChoice::ToolCall(name.clone(), input.clone()),
raw_response: response,
}),
_ => Err(CompletionError::ResponseError(
@ -157,9 +151,20 @@ impl completion::CompletionModel for CompletionModel {
&self,
completion_request: completion::CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
// Note: Ideally we'd introduce provider-specific Request models to handle the
// specific requirements of each provider. For now, we just manually check while
// building the request as a raw JSON document.
let prompt_with_context = completion_request.prompt_with_context();
let request = json!({
// Check if max_tokens is set, required for Anthropic
if completion_request.max_tokens.is_none() {
return Err(CompletionError::RequestError(
"max_tokens must be set for Anthropic".into(),
));
}
let mut request = json!({
"model": self.model,
"messages": completion_request
.chat_history
@ -172,38 +177,48 @@ impl completion::CompletionModel for CompletionModel {
.collect::<Vec<_>>(),
"max_tokens": completion_request.max_tokens,
"system": completion_request.preamble.unwrap_or("".to_string()),
"temperature": completion_request.temperature,
"tools": completion_request
.tools
.into_iter()
.map(|tool| ToolDefinition {
name: tool.name,
description: Some(tool.description),
input_schema: tool.parameters,
cache_control: None,
})
.collect::<Vec<_>>(),
});
let request = if let Some(ref params) = completion_request.additional_params {
json_utils::merge(request, params.clone())
} else {
request
};
if let Some(temperature) = completion_request.temperature {
json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
}
if !completion_request.tools.is_empty() {
json_utils::merge_inplace(
&mut request,
json!({
"tools": completion_request
.tools
.into_iter()
.map(|tool| ToolDefinition {
name: tool.name,
description: Some(tool.description),
input_schema: tool.parameters,
})
.collect::<Vec<_>>(),
"tool_choice": ToolChoice::Auto,
}),
);
}
if let Some(ref params) = completion_request.additional_params {
json_utils::merge_inplace(&mut request, params.clone())
}
let response = self
.client
.post("/v1/messages")
.json(&request)
.send()
.await?
.error_for_status()?
.json::<ApiResponse<CompletionResponse>>()
.await?;
match response {
ApiResponse::Message(completion) => completion.try_into(),
ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)),
if response.status().is_success() {
match response.json::<ApiResponse<CompletionResponse>>().await? {
ApiResponse::Message(completion) => completion.try_into(),
ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)),
}
} else {
Err(CompletionError::ProviderError(response.text().await?))
}
}
}

View File

@ -57,6 +57,13 @@ impl Client {
}
}
/// Create a new Cohere client from the `COHERE_API_KEY` environment variable.
///
/// # Panics
/// Panics if the environment variable cannot be read, for example because it is not set.
pub fn from_env() -> Self {
    Self::new(&std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set"))
}
pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
self.http_client.post(url)
@ -192,8 +199,10 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
async fn embed_documents(
&self,
documents: Vec<String>,
documents: impl IntoIterator<Item = String>,
) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
let documents = documents.into_iter().collect::<Vec<_>>();
let response = self
.client
.post("/v1/embed")
@ -203,32 +212,33 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
"input_type": self.input_type,
}))
.send()
.await?
.error_for_status()?
.json::<ApiResponse<EmbeddingResponse>>()
.await?;
match response {
ApiResponse::Ok(response) => {
if response.embeddings.len() != documents.len() {
return Err(EmbeddingError::DocumentError(format!(
"Expected {} embeddings, got {}",
documents.len(),
response.embeddings.len()
)));
}
if response.status().is_success() {
match response.json::<ApiResponse<EmbeddingResponse>>().await? {
ApiResponse::Ok(response) => {
if response.embeddings.len() != documents.len() {
return Err(EmbeddingError::DocumentError(format!(
"Expected {} embeddings, got {}",
documents.len(),
response.embeddings.len()
)));
}
Ok(response
.embeddings
.into_iter()
.zip(documents.into_iter())
.map(|(embedding, document)| embeddings::Embedding {
document,
vec: embedding,
})
.collect())
Ok(response
.embeddings
.into_iter()
.zip(documents.into_iter())
.map(|(embedding, document)| embeddings::Embedding {
document,
vec: embedding,
})
.collect())
}
ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)),
}
ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)),
} else {
Err(EmbeddingError::ProviderError(response.text().await?))
}
}
}
@ -500,14 +510,15 @@ impl completion::CompletionModel for CompletionModel {
},
)
.send()
.await?
.error_for_status()?
.json::<ApiResponse<CompletionResponse>>()
.await?;
match response {
ApiResponse::Ok(completion) => Ok(completion.into()),
ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
if response.status().is_success() {
match response.json::<ApiResponse<CompletionResponse>>().await? {
ApiResponse::Ok(completion) => Ok(completion.into()),
ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
}
} else {
Err(CompletionError::ProviderError(response.text().await?))
}
}
}

View File

@ -241,8 +241,10 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
async fn embed_documents(
&self,
documents: Vec<String>,
documents: impl IntoIterator<Item = String>,
) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
let documents = documents.into_iter().collect::<Vec<_>>();
let response = self
.client
.post("/v1/embeddings")
@ -251,30 +253,31 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
"input": documents,
}))
.send()
.await?
.error_for_status()?
.json::<ApiResponse<EmbeddingResponse>>()
.await?;
match response {
ApiResponse::Ok(response) => {
if response.data.len() != documents.len() {
return Err(EmbeddingError::ResponseError(
"Response data length does not match input length".into(),
));
}
if response.status().is_success() {
match response.json::<ApiResponse<EmbeddingResponse>>().await? {
ApiResponse::Ok(response) => {
if response.data.len() != documents.len() {
return Err(EmbeddingError::ResponseError(
"Response data length does not match input length".into(),
));
}
Ok(response
.data
.into_iter()
.zip(documents.into_iter())
.map(|(embedding, document)| embeddings::Embedding {
document,
vec: embedding.embedding,
})
.collect())
Ok(response
.data
.into_iter()
.zip(documents.into_iter())
.map(|(embedding, document)| embeddings::Embedding {
document,
vec: embedding.embedding,
})
.collect())
}
ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
}
ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
} else {
Err(EmbeddingError::ProviderError(response.text().await?))
}
}
}
@ -510,14 +513,15 @@ impl completion::CompletionModel for CompletionModel {
},
)
.send()
.await?
.error_for_status()?
.json::<ApiResponse<CompletionResponse>>()
.await?;
match response {
ApiResponse::Ok(response) => response.try_into(),
ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
if response.status().is_success() {
match response.json::<ApiResponse<CompletionResponse>>().await? {
ApiResponse::Ok(response) => response.try_into(),
ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
}
} else {
Err(CompletionError::ProviderError(response.text().await?))
}
}
}

View File

@ -231,14 +231,15 @@ impl completion::CompletionModel for CompletionModel {
},
)
.send()
.await?
.error_for_status()?
.json::<ApiResponse<CompletionResponse>>()
.await?;
match response {
ApiResponse::Ok(completion) => Ok(completion.try_into()?),
ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
if response.status().is_success() {
match response.json::<ApiResponse<CompletionResponse>>().await? {
ApiResponse::Ok(completion) => Ok(completion.try_into()?),
ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
}
} else {
Err(CompletionError::ProviderError(response.text().await?))
}
}
}

Binary file not shown.

Binary file not shown.

55
rig-lancedb/CHANGELOG.md Normal file
View File

@ -0,0 +1,55 @@
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
## [0.1.0](https://github.com/0xPlaygrounds/rig/releases/tag/rig-lancedb-v0.1.0) - 2024-10-24
### Added
- update examples to use new version of VectorStoreIndex trait
- replace document embeddings with serde json value
- merge all arrow columns into JSON document in deserializer
- finish implementing deserialiser for record batch
- implement deserialization for any recordbatch returned from lanceDB
- add indexes and tables for simple search
- create enum for embedding models
- add vector_search_s3_ann example
- implement ANN search example
- start implementing top_n_from_query for trait VectorStoreIndex
- implement get_document method of VectorStore trait
- implement search by id for VectorStore trait
- implement add_documents on VectorStore trait
- start implementing VectorStore trait for lancedb
### Fixed
- update lancedb examples test data
- make PR changes Pt II
- make PR changes pt I
- *(lancedb)* replace VectorStoreIndexDyn with VectorStoreIndex in examples
- mongodb vector search - use num_candidates from search params
- fix bug in deserializing type run end
- make PR requested changes
- reduce openai generated content in ANN examples
### Other
- cargo fmt
- lance db examples
- add example docstring
- add doc strings
- update rig core version on lancedb crate, remove implementation of VectorStore trait
- remove print statement
- use constants instead of enum for model names
- remove associated type on VectorStoreIndex trait
- cargo fmt
- conversions from arrow types to primitive types
- Add doc strings to utility methods
- add doc string to mongodb search params struct
- Merge branch 'main' into feat(vector-store)/lancedb
- create wrapper for vec<DocumentEmbeddings> for from/tryfrom traits

View File

@ -2,10 +2,14 @@
name = "rig-lancedb"
version = "0.1.0"
edition = "2021"
license = "MIT"
readme = "README.md"
description = "Rig vector store index integration for LanceDB."
repository = "https://github.com/0xPlaygrounds/rig"
[dependencies]
lancedb = "0.10.0"
rig-core = { path = "../rig-core", version = "0.2.1" }
rig-core = { path = "../rig-core", version = "0.3.0" }
arrow-array = "52.2.0"
serde_json = "1.0.128"
serde = "1.0.210"

7
rig-lancedb/LICENSE Normal file
View File

@ -0,0 +1,7 @@
Copyright (c) 2024, Playgrounds Analytics Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

2
rig-lancedb/README.md Normal file
View File

@ -0,0 +1,2 @@
# Rig-lancedb
Vector store index integration for LanceDB

View File

@ -7,6 +7,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
## [0.1.3](https://github.com/0xPlaygrounds/rig/compare/rig-mongodb-v0.1.2...rig-mongodb-v0.1.3) - 2024-10-24
### Fixed
- make PR changes pt I
- mongodb vector search - use num_candidates from search params
### Other
- Merge branch 'main' into feat(vector-store)/lancedb
## [0.1.2](https://github.com/0xPlaygrounds/rig/compare/rig-mongodb-v0.1.1...rig-mongodb-v0.1.2) - 2024-10-01
### Other

View File

@ -1,6 +1,6 @@
[package]
name = "rig-mongodb"
version = "0.1.2"
version = "0.1.3"
edition = "2021"
license = "MIT"
readme = "README.md"
@ -12,7 +12,7 @@ repository = "https://github.com/0xPlaygrounds/rig"
[dependencies]
futures = "0.3.30"
mongodb = "2.8.2"
rig-core = { path = "../rig-core", version = "0.2.1" }
rig-core = { path = "../rig-core", version = "0.3.0" }
serde = { version = "1.0.203", features = ["derive"] }
serde_json = "1.0.117"
tracing = "0.1.40"

View File

@ -1,2 +1,34 @@
# Rig-mongodb
This project implements a Rig vector store based on MongoDB.
<div style="display: flex; align-items: center; justify-content: center;">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="../img/rig_logo_dark.svg">
<source media="(prefers-color-scheme: light)" srcset="../img/rig_logo.svg">
<img src="../img/rig_logo.svg" width="200" alt="Rig logo">
</picture>
<span style="font-size: 48px; margin: 0 20px; font-weight: regular; font-family: Open Sans, sans-serif;"> + </span>
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://companieslogo.com/img/orig/MDB_BIG.D-96d632a9.png?t=1720244492">
<source media="(prefers-color-scheme: light)" srcset="https://cdn.iconscout.com/icon/free/png-256/free-mongodb-logo-icon-download-in-svg-png-gif-file-formats--wordmark-programming-langugae-freebies-pack-logos-icons-1175140.png?f=webp&w=256">
<img src="https://cdn.iconscout.com/icon/free/png-256/free-mongodb-logo-icon-download-in-svg-png-gif-file-formats--wordmark-programming-langugae-freebies-pack-logos-icons-1175140.png?f=webp&w=256" width="200" alt="MongoDB logo">
</picture>
</div>
<br><br>
## Rig-MongoDB
This companion crate implements a Rig vector store based on MongoDB.
## Usage
Add the companion crate to your `Cargo.toml`, along with the rig-core crate:
```toml
[dependencies]
rig-mongodb = "0.1.3"
rig-core = "0.3.0"
```
You can also run `cargo add rig-mongodb rig-core` to add the most recent versions of the dependencies to your project.
See the [examples](./examples) folder for usage examples.

View File

@ -1,12 +1,12 @@
use mongodb::{options::ClientOptions, Client as MongoClient, Collection};
use std::env;
use rig::vector_store::VectorStore;
use rig::{
embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::{VectorStore, VectorStoreIndex},
vector_store::VectorStoreIndex,
};
use rig_mongodb::{MongoDbVectorStore, SearchParams};
use std::env;
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {