This commit is contained in:
Joshua Mo 2025-04-14 15:37:03 +00:00 committed by GitHub
commit 1bc59da156
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 1491 additions and 1 deletions

1
Cargo.lock generated
View File

@ -8612,6 +8612,7 @@ dependencies = [
"thiserror 1.0.69",
"tokio",
"tokio-test",
"tower 0.5.2",
"tracing",
"tracing-subscriber",
"worker",

View File

@ -9,6 +9,7 @@ members = [
"rig-qdrant",
"rig-core/rig-core-derive",
"rig-sqlite",
"rig-eternalai", "rig-fastembed",
"rig-eternalai",
"rig-fastembed",
"rig-surrealdb",
]

View File

@ -39,6 +39,7 @@ bytes = "1.9.0"
async-stream = "0.3.6"
mime_guess = { version = "2.0.5" }
base64 = { version = "0.22.1" }
tower = "0.5.2"
[dev-dependencies]

View File

@ -0,0 +1,80 @@
use rig::{
completion::{CompletionRequestBuilder, ToolDefinition},
middlewares::{
completion::{CompletionLayer, CompletionService},
tools::ToolLayer,
},
providers::openai::Client,
tool::{Tool, ToolSet},
};
use serde::{Deserialize, Serialize};
use tower::Service;
use tower::ServiceBuilder;
/// Sends one tool-enabled prompt through a Tower middleware stack and prints the result.
///
/// The stack is assembled with `ServiceBuilder`: a `CompletionLayer`, then a `ToolLayer`
/// exposing the `Add` tool, around a `CompletionService` for the same model.
#[tokio::main]
async fn main() {
    let client = Client::from_env();
    let model = client.completion_model("gpt-4o");

    // Layers are listed in the order they are handed to the builder; the core service comes last.
    let mut pipeline = ServiceBuilder::new()
        .layer(CompletionLayer::builder(model.clone()).build())
        .layer(ToolLayer::new(ToolSet::from_tools(vec![Add])))
        .service(CompletionService::new(model.clone()));

    let request = CompletionRequestBuilder::new(model, "Please calculate 5+5 for me").build();

    // Any failure simply aborts the example.
    let res = pipeline.call(request).await.unwrap();
    println!("{res:?}");
}
/// Marker type for the example tool; its `Tool` implementation adds two integers.
#[derive(Deserialize, Serialize)]
struct Add;
/// Error type declared for the `Add` tool. Displays as "Math error".
#[derive(Debug, thiserror::Error)]
#[error("Math error")]
struct MathError;
/// Arguments the `Add` tool is called with, deserialized from the tool-call payload.
#[derive(Deserialize)]
struct OperationArgs {
/// First operand.
x: i32,
/// Second operand.
y: i32,
}
impl Tool for Add {
const NAME: &'static str = "add";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
serde_json::from_value(serde_json::json!({
"name": "add",
"description": "Add x and y together",
"parameters": {
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The first number to add"
},
"y": {
"type": "number",
"description": "The second number to add"
}
}
}
}))
.expect("Tool Definition")
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let result = args.x + args.y;
Ok(result)
}
}

View File

@ -235,6 +235,7 @@ pub trait CompletionModel: Clone + Send + Sync {
}
/// Struct representing a general completion request that can be sent to a completion model provider.
#[derive(Clone)]
pub struct CompletionRequest {
/// The prompt to be sent to the completion model provider
pub prompt: Message,

View File

@ -91,6 +91,7 @@ pub mod extractor;
pub mod image_generation;
pub(crate) mod json_utils;
pub mod loaders;
pub mod middlewares;
pub mod one_or_many;
pub mod pipeline;
pub mod providers;

View File

@ -0,0 +1,532 @@
use std::{future::Future, pin::Pin, task::Poll};
use tower::{Layer, Service};
use crate::{
completion::{
CompletionModel, CompletionRequest, CompletionRequestBuilder, Document, ToolDefinition,
},
message::Message,
};
/// A Tower layer to finish building your `CompletionRequestBuilder`.
/// Intended to be used with [`CompletionRequestBuilderService`].
///
/// The service it produces calls `build()` on the `CompletionRequestBuilder` returned by the
/// services it wraps, so a stack topped with this layer yields a `CompletionRequest`.
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct FinishBuilding;
impl<S> Layer<S> for FinishBuilding {
type Service = FinishBuildingService<S>;
/// Wraps `inner` in a [`FinishBuildingService`].
fn layer(&self, inner: S) -> Self::Service {
FinishBuildingService { inner }
}
}
/// A Tower layer to finish building your `CompletionRequestBuilder`.
/// Not intended to be used directly. Use [`FinishBuilding`] instead.
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct FinishBuildingService<S> {
inner: S,
}
impl<M, Msg, S> Service<(M, Msg)> for FinishBuildingService<S>
where
M: CompletionModel + Send + 'static,
Msg: Into<Message> + Send + 'static,
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = CompletionRequest;
type Error = ();
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: (M, Msg)) -> Self::Future {
let mut inner = self.inner.clone();
Box::pin(async move {
let Ok(res) = inner.call(req).await else {
todo!("Handle error properly");
};
Ok(res.build())
})
}
}
/// Tower layer that attaches a fixed list of documents to a `CompletionRequestBuilder`.
/// Meant to be stacked around [`CompletionRequestBuilderService`].
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct DocumentsLayer {
    documents: Vec<Document>,
}

impl DocumentsLayer {
    /// Creates a layer carrying `documents`.
    pub fn new(documents: Vec<Document>) -> Self {
        DocumentsLayer { documents }
    }
}

impl<S> Layer<S> for DocumentsLayer {
    type Service = DocumentsLayerService<S>;

    /// Wraps `inner`, handing the new service its own copy of the documents.
    fn layer(&self, inner: S) -> Self::Service {
        let documents = self.documents.clone();
        DocumentsLayerService { inner, documents }
    }
}
/// A Tower service to add documents to your `CompletionRequestBuilder`.
/// Not intended to be used directly - use [`DocumentsLayer`] instead.
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct DocumentsLayerService<S> {
inner: S,
documents: Vec<Document>,
}
impl<M, Msg, S> Service<(M, Msg)> for DocumentsLayerService<S>
where
M: CompletionModel + Send + 'static,
Msg: Into<Message> + Send + 'static,
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = CompletionRequestBuilder<M>;
type Error = ();
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: (M, Msg)) -> Self::Future {
let mut inner = self.inner.clone();
let documents = self.documents.clone();
Box::pin(async move {
let Ok(res) = inner.call(req).await else {
todo!("Handle error properly");
};
Ok(res.documents(documents))
})
}
}
/// Tower layer that sets a temperature value on a `CompletionRequestBuilder`.
/// Meant to be stacked around [`CompletionRequestBuilderService`].
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct TemperatureLayer {
    temperature: f64,
}

impl TemperatureLayer {
    /// Creates a layer carrying `temperature`.
    pub fn new(temperature: f64) -> Self {
        TemperatureLayer { temperature }
    }
}

impl<S> Layer<S> for TemperatureLayer {
    type Service = TemperatureLayerService<S>;

    /// Wraps `inner`, copying the configured temperature into the new service.
    fn layer(&self, inner: S) -> Self::Service {
        let temperature = self.temperature;
        TemperatureLayerService { inner, temperature }
    }
}
/// A Tower service to add a temperature value to your `CompletionRequestBuilder`.
/// Not intended to be used directly - use [`TemperatureLayer`] instead.
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct TemperatureLayerService<S> {
inner: S,
temperature: f64,
}
impl<M, Msg, S> Service<(M, Msg)> for TemperatureLayerService<S>
where
M: CompletionModel + Send + 'static,
Msg: Into<Message> + Send + 'static,
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = CompletionRequestBuilder<M>;
type Error = ();
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: (M, Msg)) -> Self::Future {
let mut inner = self.inner.clone();
let temperature = self.temperature;
Box::pin(async move {
let Ok(res) = inner.call(req).await else {
todo!("Handle error properly");
};
Ok(res.temperature(temperature))
})
}
}
/// Tower layer that adds tool definitions to a `CompletionRequestBuilder`.
/// Meant to be stacked around [`CompletionRequestBuilderService`].
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct ToolsLayer {
    tools: Vec<ToolDefinition>,
}

impl ToolsLayer {
    /// Creates a layer carrying `tools`.
    pub fn new(tools: Vec<ToolDefinition>) -> Self {
        ToolsLayer { tools }
    }
}

impl<S> Layer<S> for ToolsLayer {
    type Service = ToolsLayerService<S>;

    /// Wraps `inner`, handing the new service its own copy of the tool definitions.
    fn layer(&self, inner: S) -> Self::Service {
        let tools = self.tools.clone();
        ToolsLayerService { inner, tools }
    }
}
/// A Tower service to add tools to your `CompletionRequestBuilder`.
/// Not intended to be used directly - use [`ToolsLayer`] instead.
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct ToolsLayerService<S> {
inner: S,
tools: Vec<ToolDefinition>,
}
impl<M, Msg, S> Service<(M, Msg)> for ToolsLayerService<S>
where
M: CompletionModel + Send + 'static,
Msg: Into<Message> + Send + 'static,
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = CompletionRequestBuilder<M>;
type Error = ();
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: (M, Msg)) -> Self::Future {
let mut inner = self.inner.clone();
let tools = self.tools.clone();
Box::pin(async move {
let Ok(res) = inner.call(req).await else {
todo!("Handle error properly");
};
Ok(res.tools(tools))
})
}
}
/// A Tower layer to add a preamble ("system message") to your `CompletionRequestBuilder`.
/// Intended to be used with [`CompletionRequestBuilderService`].
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct PreambleLayer {
    preamble: String,
}

impl PreambleLayer {
    /// Creates a layer carrying `preamble`.
    ///
    /// Accepts anything convertible into a `String`, so both `String` values and string
    /// literals can be passed.
    pub fn new(preamble: impl Into<String>) -> Self {
        Self {
            preamble: preamble.into(),
        }
    }
}

impl<S> Layer<S> for PreambleLayer {
    type Service = PreambleLayerService<S>;

    /// Wraps `inner`, handing the new service its own copy of the preamble.
    fn layer(&self, inner: S) -> Self::Service {
        PreambleLayerService {
            inner,
            preamble: self.preamble.clone(),
        }
    }
}
/// A Tower service to add a preamble ("system message") to your `CompletionRequestBuilder`.
/// Not intended to be used directly - use [`PreambleLayer`] instead.
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct PreambleLayerService<S> {
inner: S,
preamble: String,
}
impl<M, Msg, S> Service<(M, Msg)> for PreambleLayerService<S>
where
M: CompletionModel + Send + 'static,
Msg: Into<Message> + Send + 'static,
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = CompletionRequestBuilder<M>;
type Error = ();
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: (M, Msg)) -> Self::Future {
let mut inner = self.inner.clone();
let preamble = self.preamble.clone();
Box::pin(async move {
let Ok(res) = inner.call(req).await else {
todo!("Handle error properly");
};
Ok(res.preamble(preamble))
})
}
}
/// Tower layer that sets additional parameters on a `CompletionRequestBuilder`, for options not
/// covered by the other layers.
/// Meant to be stacked around [`CompletionRequestBuilderService`].
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct AdditionalParamsLayer {
    additional_params: serde_json::Value,
}

impl AdditionalParamsLayer {
    /// Creates a layer carrying `additional_params`.
    pub fn new(additional_params: serde_json::Value) -> Self {
        AdditionalParamsLayer { additional_params }
    }
}

impl<S> Layer<S> for AdditionalParamsLayer {
    type Service = AdditionalParamsLayerService<S>;

    /// Wraps `inner`, handing the new service its own copy of the parameters.
    fn layer(&self, inner: S) -> Self::Service {
        let additional_params = self.additional_params.clone();
        AdditionalParamsLayerService {
            inner,
            additional_params,
        }
    }
}
/// A Tower layer to add additional parameters to your `CompletionRequestBuilder` (which are not already covered by the other parameters).
/// Not intended to be used directly - use [`AdditionalParamsLayer`] instead.
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct AdditionalParamsLayerService<S> {
inner: S,
additional_params: serde_json::Value,
}
impl<M, Msg, S> Service<(M, Msg)> for AdditionalParamsLayerService<S>
where
M: CompletionModel + Send + 'static,
Msg: Into<Message> + Send + 'static,
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = CompletionRequestBuilder<M>;
type Error = ();
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: (M, Msg)) -> Self::Future {
let mut inner = self.inner.clone();
let additional_params = self.additional_params.clone();
Box::pin(async move {
let Ok(res) = inner.call(req).await else {
todo!("Handle error properly");
};
Ok(res.additional_params(additional_params))
})
}
}
/// Tower layer that sets a maximum-tokens parameter on a `CompletionRequestBuilder`.
/// Meant to be stacked around [`CompletionRequestBuilderService`].
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct MaxTokensLayer {
    max_tokens: u64,
}

impl MaxTokensLayer {
    /// Creates a layer carrying `max_tokens`.
    pub fn new(max_tokens: u64) -> Self {
        MaxTokensLayer { max_tokens }
    }
}

impl<S> Layer<S> for MaxTokensLayer {
    type Service = MaxTokensLayerService<S>;

    /// Wraps `inner`, copying the configured limit into the new service.
    fn layer(&self, inner: S) -> Self::Service {
        let max_tokens = self.max_tokens;
        MaxTokensLayerService { inner, max_tokens }
    }
}
/// A Tower service to add a maximum tokens parameter to your `CompletionRequestBuilder`.
/// Not intended to be used directly - use [`MaxTokensLayer`] instead.
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct MaxTokensLayerService<S> {
inner: S,
max_tokens: u64,
}
impl<M, Msg, S> Service<(M, Msg)> for MaxTokensLayerService<S>
where
M: CompletionModel + Send + 'static,
Msg: Into<Message> + Send + 'static,
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = CompletionRequestBuilder<M>;
type Error = ();
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: (M, Msg)) -> Self::Future {
let mut inner = self.inner.clone();
let max_tokens = self.max_tokens;
Box::pin(async move {
let Ok(res) = inner.call(req).await else {
todo!("Handle error properly");
};
Ok(res.max_tokens(max_tokens))
})
}
}
/// Tower layer that adds a chat history to a `CompletionRequestBuilder`.
/// Meant to be stacked around [`CompletionRequestBuilderService`].
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct ChatHistoryLayer {
    chat_history: Vec<Message>,
}

impl ChatHistoryLayer {
    /// Creates a layer carrying `chat_history`.
    pub fn new(chat_history: Vec<Message>) -> Self {
        ChatHistoryLayer { chat_history }
    }
}

impl<S> Layer<S> for ChatHistoryLayer {
    type Service = ChatHistoryLayerService<S>;

    /// Wraps `inner`, handing the new service its own copy of the history.
    fn layer(&self, inner: S) -> Self::Service {
        let chat_history = self.chat_history.clone();
        ChatHistoryLayerService {
            inner,
            chat_history,
        }
    }
}
/// A Tower service to add a chat history to your `CompletionRequestBuilder`.
/// Not intended to be used directly - use [`ChatHistoryLayer`] instead.
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct ChatHistoryLayerService<S> {
inner: S,
chat_history: Vec<Message>,
}
impl<M, Msg, S> Service<(M, Msg)> for ChatHistoryLayerService<S>
where
M: CompletionModel + Send + 'static,
Msg: Into<Message> + Send + 'static,
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = CompletionRequestBuilder<M>;
type Error = ();
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: (M, Msg)) -> Self::Future {
let mut inner = self.inner.clone();
let chat_history = self.chat_history.clone();
Box::pin(async move {
let Ok(res) = inner.call(req).await else {
todo!("Handle error properly");
};
Ok(res.messages(chat_history))
})
}
}
/// A `tower` service for building a completion request.
///
/// This is the core of the stack: it turns a `(model, prompt)` pair into a fresh
/// `CompletionRequestBuilder`. Every layer wrapped around it first runs the services it wraps and
/// then adjusts the builder they return, so the builder is modified from the core outwards: the
/// layer listed last runs first and the layer listed first (here [`FinishBuilding`]) runs last.
///
/// Usage:
/// ```rust,ignore
/// use rig::completion::CompletionModel;
/// use tower::Service;
///
/// let client = rig::providers::openai::Client::from_env();
/// let model = client.completion_model("gpt-4o");
///
/// let mut service = tower::ServiceBuilder::new()
///     .layer(FinishBuilding)
///     .layer(TemperatureLayer::new(0.0))
///     .layer(PreambleLayer::new("You are a helpful assistant".to_string()))
///     .service(CompletionRequestBuilderService);
///
/// // With `FinishBuilding` on top, the stack yields a finished `CompletionRequest`.
/// let request = service
///     .call((model.clone(), "Hello world!".to_string()))
///     .await
///     .unwrap();
///
/// let res = model.completion(request).await.unwrap();
///
/// println!("{res:?}");
/// ```
pub struct CompletionRequestBuilderService;
/// Core service of the request-building stack: creates a new builder from a model and a prompt.
impl<M, Msg> Service<(M, Msg)> for CompletionRequestBuilderService
where
    M: CompletionModel + Send + 'static,
    Msg: Into<Message> + Send + 'static,
{
    type Response = CompletionRequestBuilder<M>;
    type Error = ();
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Always ready: the service keeps no state and waits on nothing.
    fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Resolves to `CompletionRequestBuilder::new(model, prompt)`; this never fails.
    fn call(&mut self, request: (M, Msg)) -> Self::Future {
        let (model, prompt) = request;
        Box::pin(async move { Ok(CompletionRequestBuilder::new(model, prompt)) })
    }
}

View File

@ -0,0 +1,262 @@
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower::{Layer, Service};
use crate::{
completion::{
CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder,
CompletionResponse, Document, ToolDefinition,
},
message::{Message, ToolResultContent, UserContent},
OneOrMany,
};
/// Tower layer holding a completion model and the settings used to build a completion request.
///
/// Every field is copied into the `CompletionLayerService` this layer produces.
/// Create it through [`CompletionLayer::builder`].
#[derive(Clone)]
pub struct CompletionLayer<M> {
/// Model the produced service sends its request to.
model: M,
/// Optional preamble for the request.
preamble: Option<String>,
/// Documents attached to the request.
documents: Vec<Document>,
/// Tool definitions attached to the request.
tools: Vec<ToolDefinition>,
/// Optional temperature for the request.
temperature: Option<f64>,
/// Optional token limit for the request.
max_tokens: Option<u64>,
/// Optional extra parameters for the request.
additional_params: Option<serde_json::Value>,
}
impl<M> CompletionLayer<M>
where
M: CompletionModel,
{
/// Starts a [`CompletionLayerBuilder`] for `model` with no other settings configured.
pub fn builder(model: M) -> CompletionLayerBuilder<M> {
CompletionLayerBuilder::new(model)
}
}
/// Builder for [`CompletionLayer`].
///
/// Each setter consumes the builder and returns it with one field replaced; the `*_opt`
/// variants take the optional value directly.
#[derive(Default)]
pub struct CompletionLayerBuilder<M> {
    model: M,
    preamble: Option<String>,
    documents: Vec<Document>,
    tools: Vec<ToolDefinition>,
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    additional_params: Option<serde_json::Value>,
}

impl<M> CompletionLayerBuilder<M>
where
    M: CompletionModel,
{
    /// Starts a builder for `model` with empty lists and all optional settings unset.
    pub fn new(model: M) -> Self {
        Self {
            model,
            preamble: None,
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        }
    }

    /// Sets the preamble.
    pub fn preamble(self, preamble: String) -> Self {
        self.preamble_opt(Some(preamble))
    }

    /// Sets or clears the preamble.
    pub fn preamble_opt(self, preamble: Option<String>) -> Self {
        Self { preamble, ..self }
    }

    /// Replaces the document list.
    pub fn documents(self, documents: Vec<Document>) -> Self {
        Self { documents, ..self }
    }

    /// Replaces the tool definition list.
    pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
        Self { tools, ..self }
    }

    /// Sets the temperature.
    pub fn temperature(self, temperature: f64) -> Self {
        self.temperature_opt(Some(temperature))
    }

    /// Sets or clears the temperature.
    pub fn temperature_opt(self, temperature: Option<f64>) -> Self {
        Self {
            temperature,
            ..self
        }
    }

    /// Sets the token limit.
    pub fn max_tokens(self, max_tokens: u64) -> Self {
        self.max_tokens_opt(Some(max_tokens))
    }

    /// Sets or clears the token limit.
    pub fn max_tokens_opt(self, max_tokens: Option<u64>) -> Self {
        Self { max_tokens, ..self }
    }

    /// Sets the additional parameters.
    pub fn additional_params(self, params: serde_json::Value) -> Self {
        self.additional_params_opt(Some(params))
    }

    /// Sets or clears the additional parameters.
    pub fn additional_params_opt(self, params: Option<serde_json::Value>) -> Self {
        Self {
            additional_params: params,
            ..self
        }
    }

    /// Consumes the builder and produces the configured [`CompletionLayer`].
    pub fn build(self) -> CompletionLayer<M> {
        let Self {
            model,
            preamble,
            documents,
            tools,
            temperature,
            max_tokens,
            additional_params,
        } = self;
        CompletionLayer {
            model,
            preamble,
            documents,
            tools,
            temperature,
            max_tokens,
            additional_params,
        }
    }
}
/// Produces a [`CompletionLayerService`] that carries a copy of every setting of this layer.
impl<M, S> Layer<S> for CompletionLayer<M>
where
    M: CompletionModel,
{
    type Service = CompletionLayerService<M, S>;

    fn layer(&self, inner: S) -> Self::Service {
        let Self {
            model,
            preamble,
            documents,
            tools,
            temperature,
            max_tokens,
            additional_params,
        } = self;
        CompletionLayerService {
            inner,
            model: model.clone(),
            preamble: preamble.clone(),
            documents: documents.clone(),
            tools: tools.clone(),
            temperature: *temperature,
            max_tokens: *max_tokens,
            additional_params: additional_params.clone(),
        }
    }
}
/// Service produced by [`CompletionLayer`].
///
/// It expects the wrapped service to answer a `CompletionRequest` with a
/// `(messages, id, tool result content)` triple, and turns that triple into a follow-up
/// completion request that is sent to the stored model.
#[derive(Clone)]
pub struct CompletionLayerService<M, S> {
    inner: S,
    model: M,
    preamble: Option<String>,
    documents: Vec<Document>,
    tools: Vec<ToolDefinition>,
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    additional_params: Option<serde_json::Value>,
}

impl<M, S> Service<CompletionRequest> for CompletionLayerService<M, S>
where
    M: CompletionModel + 'static,
    S: Service<CompletionRequest, Response = (Vec<Message>, String, ToolResultContent)>
        + Clone
        + Send
        + 'static,
    S::Future: Send,
{
    type Response = CompletionResponse<M::Response>;
    type Error = CompletionError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Always reports ready; the wrapped service is not consulted.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Runs the wrapped service, then asks the model to continue from the tool result.
    ///
    /// # Panics
    /// Panics (`todo!`) when the wrapped service returns an error.
    fn call(&mut self, req: CompletionRequest) -> Self::Future {
        // Copy everything the future needs so it does not borrow from `self`.
        let mut inner = self.inner.clone();
        let model = self.model.clone();
        let preamble = self.preamble.clone();
        let documents = self.documents.clone();
        let tools = self.tools.clone();
        let temperature = self.temperature;
        let max_tokens = self.max_tokens;
        let additional_params = self.additional_params.clone();

        Box::pin(async move {
            let (messages, id, tool_content) = match inner.call(req).await {
                Ok(parts) => parts,
                Err(_) => todo!("Handle error properly"),
            };

            // The tool result becomes the prompt of the follow-up request.
            let tool_result_message = Message::User {
                content: OneOrMany::one(UserContent::tool_result(id, OneOrMany::one(tool_content))),
            };

            let builder = CompletionRequestBuilder::new(model.clone(), tool_result_message)
                .documents(documents)
                .tools(tools)
                .messages(messages)
                .temperature_opt(temperature)
                .max_tokens_opt(max_tokens)
                .additional_params_opt(additional_params);

            let builder = match preamble {
                Some(text) => builder.preamble(text),
                None => builder,
            };

            model.completion(builder.build()).await
        })
    }
}
/// Adapter that exposes a completion model as a Tower service.
///
/// With it, an LLM model can sit wherever a Tower service is accepted, for example inside an
/// Axum web service.
#[derive(Clone)]
pub struct CompletionService<M> {
    /// Model every request is forwarded to.
    model: M,
}

impl<M> CompletionService<M>
where
    M: CompletionModel,
{
    /// Wraps `model` in a service.
    pub fn new(model: M) -> Self {
        CompletionService { model }
    }
}

impl<M> Service<CompletionRequest> for CompletionService<M>
where
    M: CompletionModel + 'static,
{
    type Response = CompletionResponse<M::Response>;
    type Error = CompletionError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Always reports ready.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Sends `req` to a clone of the stored model and resolves to the model's result.
    fn call(&mut self, req: CompletionRequest) -> Self::Future {
        // The future owns its own clone of the model, so it does not borrow from `self`.
        let completion_model = self.model.clone();
        Box::pin(async move { completion_model.completion(req).await })
    }
}

View File

@ -0,0 +1,247 @@
use std::{fmt::Debug, future::Future, marker::PhantomData, pin::Pin, task::Poll};
use tower::{Layer, Service};
use crate::{
completion::{CompletionRequest, CompletionResponse},
message::{AssistantContent, Text},
};
/// Tower layer that asks a human (through a [`HumanInTheLoop`] integration) to approve each
/// completion response before it is passed on.
#[derive(Clone)]
pub struct AwaitApprovalLayer<T> {
    integration: T,
}

impl<T> AwaitApprovalLayer<T>
where
    T: HumanInTheLoop + Clone,
{
    /// Creates a layer that asks for approval through `integration`.
    pub fn new(integration: T) -> Self {
        Self { integration }
    }

    /// Converts this layer into one that consults `predicate` before asking the human.
    ///
    /// The predicate receives a reference to a `D` and returns a boxed future resolving to
    /// `bool`. This is the closure shape [`AwaitApprovalLayerWithPredicate`] requires in order
    /// to implement `Layer`, so the value returned here can be used as a layer.
    pub fn with_predicate<R, D>(self, predicate: R) -> AwaitApprovalLayerWithPredicate<D, R, T>
    where
        D: Debug,
        R: Fn(&D) -> Pin<Box<dyn Future<Output = bool> + Send>> + Clone + Send + 'static,
    {
        AwaitApprovalLayerWithPredicate {
            integration: self.integration,
            predicate,
            _t: PhantomData,
        }
    }
}

impl<S, T> Layer<S> for AwaitApprovalLayer<T>
where
    T: HumanInTheLoop + Clone,
{
    type Service = AwaitApprovalLayerService<S, T>;

    /// Wraps `inner`, giving the new service a clone of the integration.
    fn layer(&self, inner: S) -> Self::Service {
        AwaitApprovalLayerService::new(inner, self.integration.clone())
    }
}
/// Service produced by [`AwaitApprovalLayer`]: runs the wrapped service, shows the text of the
/// first response choice to a human and waits for a decision.
#[derive(Clone)]
pub struct AwaitApprovalLayerService<S, T> {
    inner: S,
    integration: T,
}

impl<S, T> AwaitApprovalLayerService<S, T>
where
    T: HumanInTheLoop,
{
    /// Wraps `inner`, asking for approval through `integration`.
    pub fn new(inner: S, integration: T) -> Self {
        AwaitApprovalLayerService { inner, integration }
    }
}

impl<S, T, Response> Service<CompletionRequest> for AwaitApprovalLayerService<S, T>
where
    S: Service<CompletionRequest, Response = CompletionResponse<Response>> + Clone + Send + 'static,
    S::Future: Send,
    Response: Clone + 'static + Send,
    T: HumanInTheLoop + Clone + Send + 'static,
{
    type Response = CompletionResponse<Response>;
    type Error = bool;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Always reports ready; the wrapped service is not consulted.
    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Runs the wrapped service and resolves to its response once the human approves it.
    ///
    /// # Panics
    /// Every failure path is still a `todo!`: an error from the wrapped service, a first choice
    /// that is not text, a failing integration call, and a rejection by the human.
    fn call(&mut self, req: CompletionRequest) -> Self::Future {
        let mut inner = self.inner.clone();
        let integration = self.integration.clone();
        Box::pin(async move {
            let res = match inner.call(req).await {
                Ok(response) => response,
                Err(_) => todo!("Handle error properly"),
            };

            // Only the text of the first choice is presented for approval.
            let text = match res.choice.first() {
                AssistantContent::Text(Text { text }) => text,
                _ => todo!("Handle error properly"),
            };

            if integration.send_message(&text).await.is_err() {
                todo!("Handle error properly");
            }

            let Ok(approved) = integration.await_approval().await else {
                todo!("Handle error properly");
            };

            if !approved {
                todo!("Handle error properly - we should abort the pipeline here if the user wants to abort");
            }
            Ok(res)
        })
    }
}
/// Approval layer that carries a predicate alongside the [`HumanInTheLoop`] integration.
///
/// `P` is the type the predicate inspects; it is only tracked through `PhantomData`.
pub struct AwaitApprovalLayerWithPredicate<P, R, T> {
    integration: T,
    predicate: R,
    _t: PhantomData<P>,
}

impl<D, R, S, T> Layer<S> for AwaitApprovalLayerWithPredicate<D, R, T>
where
    T: HumanInTheLoop + Clone,
    D: Debug,
    R: Fn(&D) -> Pin<Box<dyn Future<Output = bool> + Send>> + Clone + Send + 'static,
{
    type Service = AwaitApprovalLayerServiceWithPredicate<D, R, S, T>;

    /// Wraps `inner`, giving the new service clones of the integration and the predicate.
    fn layer(&self, inner: S) -> Self::Service {
        AwaitApprovalLayerServiceWithPredicate::new(
            inner,
            self.integration.clone(),
            self.predicate.clone(),
        )
    }
}
/// Service produced by [`AwaitApprovalLayerWithPredicate`].
///
/// Holds the wrapped service, the [`HumanInTheLoop`] integration and the predicate. `D` is the
/// type the predicate inspects and is only tracked through `PhantomData`.
pub struct AwaitApprovalLayerServiceWithPredicate<D, R, S, T> {
inner: S,
integration: T,
predicate: R,
_t: PhantomData<D>,
}
impl<D, R, S, T> AwaitApprovalLayerServiceWithPredicate<D, R, S, T>
where
T: HumanInTheLoop,
R: Fn(&D) -> Pin<Box<dyn Future<Output = bool> + Send>> + Clone + Send + 'static,
D: Debug,
{
/// Wraps `inner` together with the approval `integration` and the `predicate`.
pub fn new(inner: S, integration: T, predicate: R) -> Self {
Self {
inner,
integration,
predicate,
_t: PhantomData,
}
}
}
/// Approval step with a short-circuit: when the predicate resolves to `true` for the response,
/// the response is returned without involving the human.
impl<D, S, T, Response, R> Service<CompletionRequest>
    for AwaitApprovalLayerServiceWithPredicate<D, R, S, T>
where
    R: Fn(&CompletionResponse<Response>) -> Pin<Box<dyn Future<Output = bool> + Send>>
        + Clone
        + Send
        + 'static,
    S: Service<CompletionRequest, Response = CompletionResponse<Response>> + Clone + Send + 'static,
    S::Future: Send,
    Response: Clone + 'static + Send,
    T: HumanInTheLoop + Clone + Send + 'static,
{
    type Response = CompletionResponse<Response>;
    type Error = bool;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Always reports ready; the wrapped service is not consulted.
    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Runs the wrapped service, then either returns the response straight away (predicate
    /// resolved to `true`) or asks the human to approve the text of the first choice.
    ///
    /// # Panics
    /// Every failure path is still a `todo!`: an error from the wrapped service, a first choice
    /// that is not text, a failing integration call, and a rejection by the human.
    fn call(&mut self, req: CompletionRequest) -> Self::Future {
        let mut inner = self.inner.clone();
        let integration = self.integration.clone();
        let predicate = self.predicate.clone();
        Box::pin(async move {
            let res = match inner.call(req).await {
                Ok(response) => response,
                Err(_) => todo!("Handle error properly"),
            };

            if predicate(&res).await {
                return Ok(res);
            }

            // Only the text of the first choice is presented for approval.
            let text = match res.choice.first() {
                AssistantContent::Text(Text { text }) => text,
                _ => todo!("Handle error properly"),
            };

            if integration.send_message(&text).await.is_err() {
                todo!("Handle error properly");
            }

            let Ok(approved) = integration.await_approval().await else {
                todo!("Handle error properly");
            };

            if !approved {
                todo!("Handle error properly - we should abort the pipeline here if the user wants to abort");
            }
            Ok(res)
        })
    }
}
/// Integration point for asking a human to approve a result.
///
/// The approval services in this module call `send_message` with the text to review and then
/// `await_approval` to obtain the decision. Both returned futures must be `Send`.
pub trait HumanInTheLoop {
/// Presents `res` (the text awaiting approval) to the human.
fn send_message(
&self,
res: &str,
) -> impl Future<Output = Result<(), Box<dyn std::error::Error>>> + Send;
/// Waits for the human's decision. The approval services treat `Ok(true)` as approval.
fn await_approval(
&self,
) -> impl Future<Output = Result<bool, Box<dyn std::error::Error>>> + Send;
}
/// [`HumanInTheLoop`] integration that prompts on standard output and reads the answer from
/// standard input.
pub struct Stdout;

impl HumanInTheLoop for Stdout {
    /// Prints the result followed by the approval question.
    ///
    /// # Errors
    /// Returns an error if standard output cannot be flushed.
    async fn send_message(&self, res: &str) -> Result<(), Box<dyn std::error::Error>> {
        print!(
            "Current result: {res}
Would you like to approve this step? [Y/n]"
        );
        // The prompt does not end with a newline, so flush explicitly to make sure it is
        // visible before the answer is read.
        std::io::Write::flush(&mut std::io::stdout())?;
        Ok(())
    }

    /// Reads lines from standard input until one is `y`/`yes` or `n`/`no` (case-insensitive,
    /// surrounding whitespace ignored).
    ///
    /// # Errors
    /// Returns an error if reading fails or if standard input is closed before a valid answer
    /// was given.
    async fn await_approval(&self) -> Result<bool, Box<dyn std::error::Error>> {
        let mut line = String::new();
        loop {
            // `read_line` appends, so the buffer must be emptied before every attempt;
            // otherwise an earlier invalid answer would stay in front of the new one.
            line.clear();
            if std::io::stdin().read_line(&mut line)? == 0 {
                // Zero bytes read means end of input: no answer can ever arrive.
                return Err("standard input was closed before an answer was given".into());
            }
            match line.trim().to_lowercase().as_str() {
                "y" | "yes" => return Ok(true),
                "n" | "no" => return Ok(false),
                _ => println!("Please respond with 'y' or 'n'."),
            }
        }
    }
}

View File

@ -0,0 +1,87 @@
use serde::Deserialize;
use std::{future::Future, marker::PhantomData, pin::Pin, task::Poll};
use tower::{Layer, Service};
use crate::{
completion::{CompletionRequest, CompletionResponse},
message::{AssistantContent, Text},
};
use super::ServiceError;
/// Tower layer that deserializes the text of a completion response into a `T`.
///
/// The layer stores no data; `T` is only tracked through `PhantomData`.
#[derive(Clone)]
pub struct ExtractorLayer<T> {
    _t: PhantomData<T>,
}

impl<T> ExtractorLayer<T>
where
    T: for<'a> Deserialize<'a>,
{
    /// Creates the layer.
    pub fn new() -> Self {
        ExtractorLayer { _t: PhantomData }
    }
}

impl<T> Default for ExtractorLayer<T>
where
    T: for<'a> Deserialize<'a>,
{
    /// Same as [`ExtractorLayer::new`].
    fn default() -> Self {
        ExtractorLayer { _t: PhantomData }
    }
}

impl<S, T> Layer<S> for ExtractorLayer<T>
where
    T: for<'a> Deserialize<'a>,
{
    type Service = ExtractorLayerService<S, T>;

    /// Wraps `inner` in an [`ExtractorLayerService`] for the same target type.
    fn layer(&self, inner: S) -> Self::Service {
        ExtractorLayerService {
            inner,
            _t: PhantomData,
        }
    }
}
#[derive(Clone)]
pub struct ExtractorLayerService<S, T> {
inner: S,
_t: PhantomData<T>,
}
impl<S, F, T> Service<CompletionRequest> for ExtractorLayerService<S, T>
where
S: Service<CompletionRequest, Response = CompletionResponse<F>, Error = ServiceError>
+ Clone
+ Send
+ 'static,
S::Future: Send,
F: 'static,
T: for<'a> Deserialize<'a> + 'static,
{
type Response = T;
type Error = ServiceError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(
&mut self,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: CompletionRequest) -> Self::Future {
let mut inner = self.inner.clone();
Box::pin(async move {
let res = inner.call(req).await?;
let AssistantContent::Text(Text { text }) = res.choice.first() else {
todo!("Handle errors properly");
};
let obj = serde_json::from_str::<T>(&text)?;
Ok(obj)
})
}
}

View File

@ -0,0 +1,39 @@
use thiserror::Error;
use crate::{
completion::CompletionError, extractor::ExtractionError, tool::ToolSetError,
vector_store::VectorStoreError,
};
pub mod build_completions;
pub mod completion;
pub mod components;
pub mod extractor;
pub mod parallel;
pub mod rag;
pub mod tools;
/// Error type shared by the middleware services in this module.
///
/// Most variants wrap an error from another part of the crate and display it unchanged.
#[derive(Debug, Error)]
pub enum ServiceError {
    /// Wraps an `ExtractionError`.
    #[error("{0}")]
    ExtractionError(#[from] ExtractionError),
    /// Wraps a `CompletionError`.
    #[error("{0}")]
    CompletionError(#[from] CompletionError),
    /// Wraps a `ToolSetError`.
    #[error("{0}")]
    ToolSetError(#[from] ToolSetError),
    /// Wraps a `VectorStoreError`.
    #[error("{0}")]
    VectorStoreError(#[from] VectorStoreError),
    /// A value that was required turned out to be missing; the payload describes it.
    #[error("Value required but was null: {0}")]
    RequiredOptionNotFound(String),
    /// Wraps a `serde_json::Error`.
    #[error("{0}")]
    Json(#[from] serde_json::Error),
    /// Any other boxed error.
    #[error("Custom error: {0}")]
    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
}

impl ServiceError {
    /// Builds a [`ServiceError::RequiredOptionNotFound`] from anything convertible to a `String`.
    pub fn required_option_not_exists<S: Into<String>>(val: S) -> Self {
        ServiceError::RequiredOptionNotFound(val.into())
    }
}

View File

@ -0,0 +1,88 @@
use std::{future::Future, pin::Pin};
use tower::Service;
use crate::completion::CompletionRequest;
use super::ServiceError;
/// Pair of two values kept under the names `inner` and `outer`.
///
/// `ParallelService` uses it to return the responses of its two services together.
pub struct Stackable<A, B> {
    pub inner: A,
    pub outer: B,
}

impl<A, B> Stackable<A, B> {
    /// Stores `inner` and `outer` side by side.
    pub fn new(inner: A, outer: B) -> Self {
        Stackable { inner, outer }
    }

    /// Consumes the pair and returns `(inner, outer)`.
    pub fn take_values(self) -> (A, B) {
        let Stackable { inner, outer } = self;
        (inner, outer)
    }
}
/// Service that feeds the same `CompletionRequest` to two services and returns both responses.
#[derive(Clone)]
pub struct ParallelService<S, T> {
    first_service: S,
    second_service: T,
}

impl<S, T> ParallelService<S, T>
where
    S: Service<CompletionRequest>,
    T: Service<CompletionRequest>,
{
    /// Combines `first_service` and `second_service` into one service.
    pub fn new(first_service: S, second_service: T) -> Self {
        ParallelService {
            first_service,
            second_service,
        }
    }
}
/// Sends one request to both services and returns their responses as a [`Stackable`]
/// (`inner` is the first service's response, `outer` the second's).
///
/// NOTE: despite the name, the two services are awaited one after the other; the second call
/// only starts once the first has succeeded.
impl<S, T> Service<CompletionRequest> for ParallelService<S, T>
where
    S: Service<CompletionRequest, Error = ServiceError> + Clone + Send + 'static,
    S::Future: Send,
    S::Response: Send + 'static,
    T: Service<CompletionRequest, Error = ServiceError> + Clone + Send + 'static,
    T::Future: Send,
    T::Response: Send + 'static,
{
    type Response = Stackable<S::Response, T::Response>;
    type Error = ServiceError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Ready only once both wrapped services are ready; the first error or pending state wins.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        match self.first_service.poll_ready(cx) {
            std::task::Poll::Ready(Ok(())) => self.second_service.poll_ready(cx),
            not_ready => not_ready,
        }
    }

    /// Calls the first service, then the second, with the same request.
    ///
    /// # Errors
    /// Returns the first error produced by either service. If the first service fails, the
    /// second is never called.
    fn call(&mut self, req: CompletionRequest) -> Self::Future {
        // Move the instances that were polled ready into the future and keep fresh clones.
        let first_clone = self.first_service.clone();
        let mut first = std::mem::replace(&mut self.first_service, first_clone);
        let second_clone = self.second_service.clone();
        let mut second = std::mem::replace(&mut self.second_service, second_clone);
        Box::pin(async move {
            let res1 = first.call(req.clone()).await?;
            // The request is not needed afterwards, so the last call can take it by value.
            let res2 = second.call(req).await?;
            Ok(Stackable::new(res1, res2))
        })
    }
}
/// Builds a (possibly nested) `ParallelService` from two or more services.
///
/// * `parallel_service!(a, b)` expands to `ParallelService::new(a, b)`.
/// * `parallel_service!(a, b, c)` expands to
///   `ParallelService::new(a, ParallelService::new(b, c))`, and so on for more arguments.
///
/// A trailing comma is accepted. At least two services are required.
#[macro_export]
macro_rules! parallel_service {
    // Base case: exactly two services.
    ($first:expr, $second:expr $(,)?) => {
        $crate::middlewares::parallel::ParallelService::new($first, $second)
    };
    // Recursive case: pair the first service with the combination of the remaining ones.
    ($first:expr, $($rest:expr),+ $(,)?) => {
        $crate::middlewares::parallel::ParallelService::new(
            $first,
            $crate::parallel_service!($($rest),+)
        )
    };
}

View File

@ -0,0 +1,64 @@
use std::{
future::Future,
marker::PhantomData,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use serde::{Deserialize, Serialize};
use tower::Service;
use crate::{completion::CompletionRequest, vector_store::VectorStoreIndex};
use super::ServiceError;
/// A service that looks up entries in a vector store index for an incoming request.
///
/// `V` is the vector store index type and `T` is the type of the values returned by
/// the lookup. `T` is never stored, hence the `PhantomData` marker.
pub struct RagService<V, T> {
    /// The index that is queried; kept behind an `Arc` so it can be shared with the
    /// future created for each call.
    vector_index: Arc<V>,
    /// Passed to the index as the number of results to request.
    num_results: usize,
    _phantom: PhantomData<T>,
}

impl<V, T> RagService<V, T>
where
    V: VectorStoreIndex,
{
    /// Creates a service that queries `vector_index`, requesting `num_results` results
    /// per lookup.
    pub fn new(vector_index: V, num_results: usize) -> Self {
        let vector_index = Arc::new(vector_index);
        Self {
            vector_index,
            num_results,
            _phantom: PhantomData,
        }
    }
}
impl<V, T> Service<CompletionRequest> for RagService<V, T>
where
V: VectorStoreIndex + 'static,
T: Serialize + for<'a> Deserialize<'a> + Send,
{
type Response = RagResult<T>;
type Error = ServiceError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: CompletionRequest) -> Self::Future {
let vector_index = self.vector_index.clone();
let num_results = self.num_results;
Box::pin(async move {
let Some(prompt) = req.prompt.rag_text() else {
return Err(ServiceError::required_option_not_exists("rag_text"));
};
let res = vector_index.top_n(&prompt, num_results).await?;
Ok(res)
})
}
}
pub type RagResult<T> = Vec<(f64, String, T)>;

View File

@ -0,0 +1,86 @@
use crate::{
completion::{CompletionRequest, CompletionResponse},
message::{AssistantContent, Message, ToolResultContent},
tool::{ToolSet, ToolSetError},
OneOrMany,
};
use std::{future::Future, pin::Pin, sync::Arc, task::Poll};
use tower::{Layer, Service};
/// A [`Layer`] that wraps a service in a [`ToolLayerService`] sharing this layer's
/// [`ToolSet`].
#[derive(Clone)]
pub struct ToolLayer {
    /// Shared with every service produced by this layer.
    tools: Arc<ToolSet>,
}

impl ToolLayer {
    /// Creates a layer that owns `tools` behind an `Arc`.
    pub fn new(tools: ToolSet) -> Self {
        let tools = Arc::new(tools);
        Self { tools }
    }
}

impl<S> Layer<S> for ToolLayer {
    type Service = ToolLayerService<S>;

    /// Wraps `inner`, giving the new service another handle to the same tool set.
    fn layer(&self, inner: S) -> Self::Service {
        let tools = self.tools.clone();
        ToolLayerService { inner, tools }
    }
}

/// The service produced by [`ToolLayer`]: an inner service plus a shared [`ToolSet`].
#[derive(Clone)]
pub struct ToolLayerService<S> {
    inner: S,
    tools: Arc<ToolSet>,
}
impl<S, T> Service<CompletionRequest> for ToolLayerService<S>
where
S: Service<CompletionRequest, Response = CompletionResponse<T>> + Clone + Send + 'static,
T: Send + 'static,
S::Future: Send,
{
type Response = (Vec<Message>, String, ToolResultContent);
type Error = ToolSetError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: CompletionRequest) -> Self::Future {
let mut inner = self.inner.clone();
let tools = self.tools.clone();
let mut messages = req.chat_history.clone();
Box::pin(async move {
let Ok(res) = inner.call(req).await else {
todo!("Handle error properly");
};
let AssistantContent::ToolCall(tool_call) = res.choice.first() else {
todo!("Handle error properly");
};
messages.push(Message::Assistant {
content: OneOrMany::one(AssistantContent::ToolCall(tool_call.clone())),
});
let Ok(res) = tools
.call(
&tool_call.function.name,
tool_call.function.arguments.to_string(),
)
.await
else {
todo!("Implement proper error handling");
};
Ok((messages, tool_call.id, ToolResultContent::text(res)))
})
}
}