feat: a bunch of stuff

This commit is contained in:
Joshua Mo 2025-04-09 17:15:24 +01:00
parent ed0baca252
commit 49221ab848
9 changed files with 748 additions and 143 deletions

View File

@ -0,0 +1,67 @@
use rig::{
completion::CompletionRequestBuilder,
middlewares::{
completion::{CompletionLayer, CompletionService},
tools::ToolLayer,
},
providers::openai::Client,
};
use tower::ServiceBuilder;
#[tokio::main]
async fn main() {
let client = Client::from_env();
let model = client.completion_model("gpt-4o");
let comp_layer = CompletionLayer::builder(model).build();
let tool_layer = ToolLayer::new(vec![Add]);
let service = CompletionService::new(model);
let service = ServiceBuilder::new()
.layer(comp_layer)
.layer(tool_layer)
.service(service);
let comp_request = CompletionRequestBuilder::new(model, "Please calculate 5+5 for me").build();
let res = service.call(comp_request).await.unwrap();
println!("{res:?}");
}
/// Example tool that adds two numbers supplied by the model.
#[derive(Deserialize, Serialize)]
struct Add;

impl Tool for Add {
    const NAME: &'static str = "add";

    type Error = MathError;
    type Args = OperationArgs;
    type Output = i32;

    /// Describes the tool (name, description and parameter schema) as a `ToolDefinition`.
    ///
    /// # Panics
    /// Panics with "Tool Definition" if the JSON below cannot be deserialized into a
    /// `ToolDefinition`.
    async fn definition(&self, _prompt: String) -> ToolDefinition {
        let schema = json!({
            "name": "add",
            "description": "Add x and y together",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The first number to add"
                    },
                    "y": {
                        "type": "number",
                        "description": "The second number to add"
                    }
                }
            }
        });
        serde_json::from_value(schema).expect("Tool Definition")
    }

    /// Returns the sum of `args.x` and `args.y`.
    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        Ok(args.x + args.y)
    }
}

View File

@ -235,6 +235,7 @@ pub trait CompletionModel: Clone + Send + Sync {
} }
/// Struct representing a general completion request that can be sent to a completion model provider. /// Struct representing a general completion request that can be sent to a completion model provider.
#[derive(Clone)]
pub struct CompletionRequest { pub struct CompletionRequest {
/// The prompt to be sent to the completion model provider /// The prompt to be sent to the completion model provider
pub prompt: Message, pub prompt: Message,

View File

@ -0,0 +1,532 @@
use std::{future::Future, pin::Pin, task::Poll};
use tower::{Layer, Service};
use crate::{
completion::{
CompletionModel, CompletionRequest, CompletionRequestBuilder, Document, ToolDefinition,
},
message::Message,
};
/// Tower layer that turns the `CompletionRequestBuilder` produced by the wrapped service
/// into a finished `CompletionRequest`.
/// Meant to sit on top of a stack based on [`CompletionRequestBuilderService`].
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct FinishBuilding;

impl<S> Layer<S> for FinishBuilding {
    type Service = FinishBuildingService<S>;

    /// Wraps `inner` in a [`FinishBuildingService`].
    fn layer(&self, inner: S) -> Self::Service {
        let wrapped = FinishBuildingService { inner };
        wrapped
    }
}
/// A Tower layer to finish building your `CompletionRequestBuilder`.
/// Not intended to be used directly. Use [`FinishBuilding`] instead.
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct FinishBuildingService<S> {
inner: S,
}
impl<M, Msg, S> Service<(M, Msg)> for FinishBuildingService<S>
where
M: CompletionModel + Send + 'static,
Msg: Into<Message> + Send + 'static,
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = CompletionRequest;
type Error = ();
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: (M, Msg)) -> Self::Future {
let mut inner = self.inner.clone();
Box::pin(async move {
let Ok(res) = inner.call(req).await else {
todo!("Handle error properly");
};
Ok(res.build())
})
}
}
/// Tower layer that attaches a fixed list of documents to your `CompletionRequestBuilder`.
/// Meant to be stacked on top of [`CompletionRequestBuilderService`].
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct DocumentsLayer {
    documents: Vec<Document>,
}

impl DocumentsLayer {
    /// Creates a layer holding `documents`; they are cloned into every service it wraps.
    pub fn new(documents: Vec<Document>) -> Self {
        DocumentsLayer { documents }
    }
}
impl<S> Layer<S> for DocumentsLayer {
    type Service = DocumentsLayerService<S>;

    /// Wraps `inner` in a [`DocumentsLayerService`] carrying a copy of the documents.
    fn layer(&self, inner: S) -> Self::Service {
        let documents = self.documents.clone();
        DocumentsLayerService { inner, documents }
    }
}
/// A Tower service to add documents to your `CompletionRequestBuilder`.
/// Not intended to be used directly - use [`DocumentsLayer`] instead.
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct DocumentsLayerService<S> {
inner: S,
documents: Vec<Document>,
}
impl<M, Msg, S> Service<(M, Msg)> for DocumentsLayerService<S>
where
M: CompletionModel + Send + 'static,
Msg: Into<Message> + Send + 'static,
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = CompletionRequestBuilder<M>;
type Error = ();
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: (M, Msg)) -> Self::Future {
let mut inner = self.inner.clone();
let documents = self.documents.clone();
Box::pin(async move {
let Ok(res) = inner.call(req).await else {
todo!("Handle error properly");
};
Ok(res.documents(documents))
})
}
}
/// Tower layer that sets a temperature value on your `CompletionRequestBuilder`.
/// Meant to be stacked on top of [`CompletionRequestBuilderService`].
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct TemperatureLayer {
    temperature: f64,
}

impl TemperatureLayer {
    /// Creates a layer that applies `temperature` to every builder passing through it.
    pub fn new(temperature: f64) -> Self {
        TemperatureLayer { temperature }
    }
}
impl<S> Layer<S> for TemperatureLayer {
    type Service = TemperatureLayerService<S>;

    /// Wraps `inner` in a [`TemperatureLayerService`] carrying this layer's temperature.
    fn layer(&self, inner: S) -> Self::Service {
        let temperature = self.temperature;
        TemperatureLayerService { inner, temperature }
    }
}
/// A Tower service to add a temperature value to your `CompletionRequestBuilder`.
/// Not intended to be used directly - use [`TemperatureLayer`] instead.
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct TemperatureLayerService<S> {
inner: S,
temperature: f64,
}
impl<M, Msg, S> Service<(M, Msg)> for TemperatureLayerService<S>
where
M: CompletionModel + Send + 'static,
Msg: Into<Message> + Send + 'static,
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = CompletionRequestBuilder<M>;
type Error = ();
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: (M, Msg)) -> Self::Future {
let mut inner = self.inner.clone();
let temperature = self.temperature;
Box::pin(async move {
let Ok(res) = inner.call(req).await else {
todo!("Handle error properly");
};
Ok(res.temperature(temperature))
})
}
}
/// Tower layer that attaches a fixed list of tool definitions to your
/// `CompletionRequestBuilder`.
/// Meant to be stacked on top of [`CompletionRequestBuilderService`].
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct ToolsLayer {
    tools: Vec<ToolDefinition>,
}

impl ToolsLayer {
    /// Creates a layer holding `tools`; they are cloned into every service it wraps.
    pub fn new(tools: Vec<ToolDefinition>) -> Self {
        ToolsLayer { tools }
    }
}
impl<S> Layer<S> for ToolsLayer {
    type Service = ToolsLayerService<S>;

    /// Wraps `inner` in a [`ToolsLayerService`] carrying a copy of the tool definitions.
    fn layer(&self, inner: S) -> Self::Service {
        let tools = self.tools.clone();
        ToolsLayerService { inner, tools }
    }
}
/// A Tower service to add tools to your `CompletionRequestBuilder`.
/// Not intended to be used directly - use [`ToolsLayer`] instead.
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct ToolsLayerService<S> {
inner: S,
tools: Vec<ToolDefinition>,
}
impl<M, Msg, S> Service<(M, Msg)> for ToolsLayerService<S>
where
M: CompletionModel + Send + 'static,
Msg: Into<Message> + Send + 'static,
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = CompletionRequestBuilder<M>;
type Error = ();
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: (M, Msg)) -> Self::Future {
let mut inner = self.inner.clone();
let tools = self.tools.clone();
Box::pin(async move {
let Ok(res) = inner.call(req).await else {
todo!("Handle error properly");
};
Ok(res.tools(tools))
})
}
}
/// A Tower layer to add a preamble ("system message") to your `CompletionRequestBuilder`.
/// Intended to be used with [`CompletionRequestBuilderService`].
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct PreambleLayer {
    preamble: String,
}

impl PreambleLayer {
    /// Creates a layer that applies `preamble` to every builder passing through it.
    ///
    /// Accepts anything convertible into a `String`, so both `String` values and string
    /// literals (`PreambleLayer::new("You are a helpful assistant")`) work.
    pub fn new(preamble: impl Into<String>) -> Self {
        Self {
            preamble: preamble.into(),
        }
    }
}
impl<S> Layer<S> for PreambleLayer {
    type Service = PreambleLayerService<S>;

    /// Wraps `inner` in a [`PreambleLayerService`] carrying a copy of the preamble.
    fn layer(&self, inner: S) -> Self::Service {
        let preamble = self.preamble.clone();
        PreambleLayerService { inner, preamble }
    }
}
/// A Tower service to add a preamble ("system message") to your `CompletionRequestBuilder`.
/// Not intended to be used directly - use [`PreambleLayer`] instead.
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct PreambleLayerService<S> {
inner: S,
preamble: String,
}
impl<M, Msg, S> Service<(M, Msg)> for PreambleLayerService<S>
where
M: CompletionModel + Send + 'static,
Msg: Into<Message> + Send + 'static,
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = CompletionRequestBuilder<M>;
type Error = ();
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: (M, Msg)) -> Self::Future {
let mut inner = self.inner.clone();
let preamble = self.preamble.clone();
Box::pin(async move {
let Ok(res) = inner.call(req).await else {
todo!("Handle error properly");
};
Ok(res.preamble(preamble))
})
}
}
/// Tower layer that attaches additional parameters (anything not covered by the other
/// layers in this module) to your `CompletionRequestBuilder`.
/// Meant to be stacked on top of [`CompletionRequestBuilderService`].
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct AdditionalParamsLayer {
    additional_params: serde_json::Value,
}

impl AdditionalParamsLayer {
    /// Creates a layer holding `additional_params`; the value is cloned into every service
    /// it wraps.
    pub fn new(additional_params: serde_json::Value) -> Self {
        AdditionalParamsLayer { additional_params }
    }
}
impl<S> Layer<S> for AdditionalParamsLayer {
    type Service = AdditionalParamsLayerService<S>;

    /// Wraps `inner` in an [`AdditionalParamsLayerService`] carrying a copy of the
    /// parameters.
    fn layer(&self, inner: S) -> Self::Service {
        let additional_params = self.additional_params.clone();
        AdditionalParamsLayerService {
            inner,
            additional_params,
        }
    }
}
/// A Tower layer to add additional parameters to your `CompletionRequestBuilder` (which are not already covered by the other parameters).
/// Not intended to be used directly - use [`AdditionalParamsLayer`] instead.
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct AdditionalParamsLayerService<S> {
inner: S,
additional_params: serde_json::Value,
}
impl<M, Msg, S> Service<(M, Msg)> for AdditionalParamsLayerService<S>
where
M: CompletionModel + Send + 'static,
Msg: Into<Message> + Send + 'static,
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = CompletionRequestBuilder<M>;
type Error = ();
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: (M, Msg)) -> Self::Future {
let mut inner = self.inner.clone();
let additional_params = self.additional_params.clone();
Box::pin(async move {
let Ok(res) = inner.call(req).await else {
todo!("Handle error properly");
};
Ok(res.additional_params(additional_params))
})
}
}
/// Tower layer that sets a maximum tokens parameter on your `CompletionRequestBuilder`.
/// Meant to be stacked on top of [`CompletionRequestBuilderService`].
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct MaxTokensLayer {
    max_tokens: u64,
}

impl MaxTokensLayer {
    /// Creates a layer that applies `max_tokens` to every builder passing through it.
    pub fn new(max_tokens: u64) -> Self {
        MaxTokensLayer { max_tokens }
    }
}
impl<S> Layer<S> for MaxTokensLayer {
    type Service = MaxTokensLayerService<S>;

    /// Wraps `inner` in a [`MaxTokensLayerService`] carrying this layer's token limit.
    fn layer(&self, inner: S) -> Self::Service {
        let max_tokens = self.max_tokens;
        MaxTokensLayerService { inner, max_tokens }
    }
}
/// A Tower service to add a maximum tokens parameter to your `CompletionRequestBuilder`.
/// Not intended to be used directly - use [`MaxTokensLayer`] instead.
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct MaxTokensLayerService<S> {
inner: S,
max_tokens: u64,
}
impl<M, Msg, S> Service<(M, Msg)> for MaxTokensLayerService<S>
where
M: CompletionModel + Send + 'static,
Msg: Into<Message> + Send + 'static,
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = CompletionRequestBuilder<M>;
type Error = ();
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: (M, Msg)) -> Self::Future {
let mut inner = self.inner.clone();
let max_tokens = self.max_tokens;
Box::pin(async move {
let Ok(res) = inner.call(req).await else {
todo!("Handle error properly");
};
Ok(res.max_tokens(max_tokens))
})
}
}
/// Tower layer that attaches a chat history to your `CompletionRequestBuilder`.
/// Meant to be stacked on top of [`CompletionRequestBuilderService`].
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct ChatHistoryLayer {
    chat_history: Vec<Message>,
}

impl ChatHistoryLayer {
    /// Creates a layer holding `chat_history`; the messages are cloned into every service
    /// it wraps.
    pub fn new(chat_history: Vec<Message>) -> Self {
        ChatHistoryLayer { chat_history }
    }
}
impl<S> Layer<S> for ChatHistoryLayer {
    type Service = ChatHistoryLayerService<S>;

    /// Wraps `inner` in a [`ChatHistoryLayerService`] carrying a copy of the history.
    fn layer(&self, inner: S) -> Self::Service {
        let chat_history = self.chat_history.clone();
        ChatHistoryLayerService {
            inner,
            chat_history,
        }
    }
}
/// A Tower service to add a chat history to your `CompletionRequestBuilder`.
/// Not intended to be used directly - use [`ChatHistoryLayer`] instead.
///
/// See [`CompletionRequestBuilderService`] for usage.
pub struct ChatHistoryLayerService<S> {
inner: S,
chat_history: Vec<Message>,
}
impl<M, Msg, S> Service<(M, Msg)> for ChatHistoryLayerService<S>
where
M: CompletionModel + Send + 'static,
Msg: Into<Message> + Send + 'static,
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = CompletionRequestBuilder<M>;
type Error = ();
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: (M, Msg)) -> Self::Future {
let mut inner = self.inner.clone();
let chat_history = self.chat_history.clone();
Box::pin(async move {
let Ok(res) = inner.call(req).await else {
todo!("Handle error properly");
};
Ok(res.messages(chat_history))
})
}
}
/// A `tower` service that starts a completion request: it turns a `(model, prompt)` pair
/// into a `CompletionRequestBuilder`.
///
/// Every layer in this module first awaits the service it wraps and then amends the
/// builder that comes back. The core service is therefore placed last, and the layers are
/// executed from bottom to top: the layer listed last is applied first.
///
/// Usage:
/// ```rust
/// use tower::Service;
///
/// let client = rig::providers::openai::Client::from_env();
/// let model = client.completion_model("gpt-4o");
///
/// let mut service = tower::ServiceBuilder::new()
///     .layer(FinishBuilding)
///     .layer(TemperatureLayer::new(0.0))
///     .layer(PreambleLayer::new("You are a helpful assistant".to_string()))
///     .service(CompletionRequestBuilderService);
///
/// // `FinishBuilding` is the outermost layer, so the call yields a built `CompletionRequest`.
/// let request = service.call((model, "Hello world!".to_string())).await.unwrap();
/// ```
pub struct CompletionRequestBuilderService;

impl<M, Msg> Service<(M, Msg)> for CompletionRequestBuilderService
where
    M: CompletionModel + Send + 'static,
    Msg: Into<Message> + Send + 'static,
{
    type Response = CompletionRequestBuilder<M>;
    type Error = ();
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Always ready: this service holds no state and wraps no inner service.
    fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        let ready: Result<(), Self::Error> = Ok(());
        Poll::Ready(ready)
    }

    /// Creates a new `CompletionRequestBuilder` from the model and prompt in `req`.
    fn call(&mut self, req: (M, Msg)) -> Self::Future {
        let (model, prompt) = req;
        Box::pin(async move {
            let builder = CompletionRequestBuilder::new(model, prompt);
            Ok(builder)
        })
    }
}

View File

@ -224,6 +224,7 @@ where
/// A completion model as a Tower service. /// A completion model as a Tower service.
/// ///
/// This allows you to use an LLM model (or client) essentially anywhere you'd use a regular Tower layer, like in an Axum web service. /// This allows you to use an LLM model (or client) essentially anywhere you'd use a regular Tower layer, like in an Axum web service.
#[derive(Clone)]
pub struct CompletionService<M> { pub struct CompletionService<M> {
/// The model itself. /// The model itself.
model: M, model: M,

View File

@ -1,68 +1,54 @@
use std::{future::Future, pin::Pin, sync::Arc, task::Poll}; use serde::Deserialize;
use std::{future::Future, marker::PhantomData, pin::Pin, task::Poll};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use tower::{Layer, Service}; use tower::{Layer, Service};
use crate::{ use crate::{
completion::{CompletionModel, CompletionRequest, CompletionResponse}, completion::{CompletionRequest, CompletionResponse},
extractor::{ExtractionError, Extractor},
message::{AssistantContent, Text}, message::{AssistantContent, Text},
}; };
pub struct ExtractorLayer<M, T> use super::ServiceError;
where
M: CompletionModel, pub struct ExtractorLayer<T> {
T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, _t: PhantomData<T>,
{
ext: Arc<Extractor<M, T>>,
} }
impl<M, T> ExtractorLayer<M, T> impl<T> ExtractorLayer<T>
where where
M: CompletionModel, T: for<'a> Deserialize<'a>,
T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync,
{ {
pub fn new(ext: Extractor<M, T>) -> Self { pub fn new() -> Self {
Self { ext: Arc::new(ext) } Self { _t: PhantomData }
} }
} }
impl<S, M, T> Layer<S> for ExtractorLayer<M, T> impl<S, T> Layer<S> for ExtractorLayer<T>
where where
M: CompletionModel + 'static, T: for<'a> Deserialize<'a>,
T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static,
{ {
type Service = ExtractorLayerService<S, M, T>; type Service = ExtractorLayerService<S, T>;
fn layer(&self, inner: S) -> Self::Service { fn layer(&self, inner: S) -> Self::Service {
ExtractorLayerService { ExtractorLayerService { inner, _t: self._t }
inner,
ext: Arc::clone(&self.ext),
}
} }
} }
pub struct ExtractorLayerService<S, M, T> pub struct ExtractorLayerService<S, T> {
where
M: CompletionModel + 'static,
T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static,
{
inner: S, inner: S,
ext: Arc<Extractor<M, T>>, _t: PhantomData<T>,
} }
impl<S, M, T> Service<CompletionRequest> for ExtractorLayerService<S, M, T> impl<S, F, T> Service<CompletionRequest> for ExtractorLayerService<S, T>
where where
S: Service<CompletionRequest, Response = CompletionResponse<M::Response>> S: Service<CompletionRequest, Response = CompletionResponse<F>, Error = ServiceError>
+ Clone + Clone
+ Send + Send
+ 'static, + 'static,
S::Future: Send, S::Future: Send,
M: CompletionModel + 'static, F: 'static,
T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync, T: for<'a> Deserialize<'a> + 'static,
{ {
type Response = T; type Response = T;
type Error = ExtractionError; type Error = ServiceError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>; type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready( fn poll_ready(
@ -73,63 +59,18 @@ where
} }
fn call(&mut self, req: CompletionRequest) -> Self::Future { fn call(&mut self, req: CompletionRequest) -> Self::Future {
let ext = self.ext.clone();
let mut inner = self.inner.clone(); let mut inner = self.inner.clone();
Box::pin(async move { Box::pin(async move {
let Ok(res) = inner.call(req).await else { let res = inner.call(req).await?;
todo!("Properly handle error");
};
let AssistantContent::Text(Text { text }) = res.choice.first() else { let AssistantContent::Text(Text { text }) = res.choice.first() else {
todo!("Handle errors properly"); todo!("Handle errors properly");
}; };
ext.extract(&text).await let obj = serde_json::from_str::<T>(&text)?;
Ok(obj)
}) })
} }
} }
pub struct ExtractorService<M, T>
where
M: CompletionModel + 'static,
T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static,
{
ext: Arc<Extractor<M, T>>,
}
impl<M, T> ExtractorService<M, T>
where
M: CompletionModel + 'static,
T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static,
{
pub fn new(ext: Extractor<M, T>) -> Self {
Self { ext: Arc::new(ext) }
}
}
impl<M, T> Service<CompletionRequest> for ExtractorService<M, T>
where
M: CompletionModel,
T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync,
{
type Response = T;
type Error = ExtractionError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(
&mut self,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: CompletionRequest) -> Self::Future {
let ext = self.ext.clone();
let Some(req) = req.prompt.rag_text() else {
todo!("Handle error properly");
};
Box::pin(async move { ext.extract(&req).await })
}
}

View File

@ -1,10 +1,15 @@
use thiserror::Error; use thiserror::Error;
use crate::{completion::CompletionError, extractor::ExtractionError, tool::ToolSetError}; use crate::{
completion::CompletionError, extractor::ExtractionError, tool::ToolSetError,
vector_store::VectorStoreError,
};
pub mod build_completions;
pub mod completion; pub mod completion;
pub mod components; pub mod components;
pub mod extractor; pub mod extractor;
pub mod parallel;
pub mod rag; pub mod rag;
pub mod tools; pub mod tools;
@ -16,4 +21,17 @@ pub enum ServiceError {
CompletionError(#[from] CompletionError), CompletionError(#[from] CompletionError),
#[error("{0}")] #[error("{0}")]
ToolSetError(#[from] ToolSetError), ToolSetError(#[from] ToolSetError),
#[error("{0}")]
VectorStoreError(#[from] VectorStoreError),
#[error("Value required but was null: {0}")]
RequiredOptionNotFound(String),
#[error("{0}")]
Json(#[from] serde_json::Error),
}
impl ServiceError {
pub fn required_option_not_exists<S: Into<String>>(val: S) -> Self {
let val: String = val.into();
Self::RequiredOptionNotFound(val)
}
} }

View File

@ -0,0 +1,88 @@
use std::{future::Future, pin::Pin};
use tower::Service;
use crate::completion::CompletionRequest;
use super::ServiceError;
/// A pair of values, labelled `inner` and `outer`.
///
/// Used as the response type of [`ParallelService`], where `inner` holds the first
/// service's response and `outer` the second's.
pub struct Stackable<A, B> {
    pub inner: A,
    pub outer: B,
}

impl<A, B> Stackable<A, B> {
    /// Bundles `inner` and `outer` into a `Stackable`.
    pub fn new(inner: A, outer: B) -> Self {
        Stackable { inner, outer }
    }

    /// Consumes the pair and returns its parts as `(inner, outer)`.
    pub fn take_values(self) -> (A, B) {
        let Stackable { inner, outer } = self;
        (inner, outer)
    }
}
/// A Tower service that hands the same `CompletionRequest` to two services and returns
/// both responses together.
#[derive(Clone)]
pub struct ParallelService<S, T> {
    first_service: S,
    second_service: T,
}

impl<S, T> ParallelService<S, T>
where
    S: Service<CompletionRequest>,
    T: Service<CompletionRequest>,
{
    /// Combines `first_service` and `second_service` into one service.
    pub fn new(first_service: S, second_service: T) -> Self {
        ParallelService {
            first_service,
            second_service,
        }
    }
}
/// Runs both wrapped services on the same request and returns their responses as a
/// [`Stackable`] (`inner` = first service, `outer` = second service).
///
/// Note: the two calls are awaited one after the other, not concurrently; the second
/// service is not called if the first one fails.
impl<S, T> Service<CompletionRequest> for ParallelService<S, T>
where
    S: Service<CompletionRequest, Error = ServiceError> + Clone + Send + 'static,
    S::Future: Send,
    S::Response: Send + 'static,
    T: Service<CompletionRequest, Error = ServiceError> + Clone + Send + 'static,
    T::Future: Send,
    T::Response: Send + 'static,
{
    type Response = Stackable<S::Response, T::Response>;
    type Error = ServiceError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Ready only once both wrapped services report readiness; the first error or
    /// `Pending` result is returned as-is.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        match self.first_service.poll_ready(cx) {
            std::task::Poll::Ready(Ok(())) => self.second_service.poll_ready(cx),
            pending_or_failed => pending_or_failed,
        }
    }

    /// Sends `req` to the first service, then to the second.
    ///
    /// # Errors
    /// Returns the first `ServiceError` produced by either service.
    fn call(&mut self, req: CompletionRequest) -> Self::Future {
        // Use the instances that `poll_ready` was driven on and leave fresh clones behind,
        // as recommended by the `tower::Service` documentation.
        let first_replacement = self.first_service.clone();
        let mut first = std::mem::replace(&mut self.first_service, first_replacement);
        let second_replacement = self.second_service.clone();
        let mut second = std::mem::replace(&mut self.second_service, second_replacement);

        Box::pin(async move {
            let res1 = first.call(req.clone()).await?;
            // Last use of the request: move it instead of cloning a second time.
            let res2 = second.call(req).await?;

            Ok(Stackable::new(res1, res2))
        })
    }
}
/// Combines two or more services into nested [`ParallelService`]s.
///
/// `parallel_service!(a, b)` expands to `ParallelService::new(a, b)`;
/// `parallel_service!(a, b, c)` expands to
/// `ParallelService::new(a, ParallelService::new(b, c))`, and so on.
/// Each argument must be a single token tree (an identifier, a literal or a
/// parenthesised/braced expression).
#[macro_export]
macro_rules! parallel_service {
    // Base case: exactly two services.
    ($service1:tt, $service2:tt) => {
        $crate::middlewares::parallel::ParallelService::new($service1, $service2)
    };
    // Three or more: pair the first service with the combination of the rest.
    ($service1:tt $(, $services:tt)+) => {
        $crate::middlewares::parallel::ParallelService::new(
            $service1,
            $crate::parallel_service!($($services),+)
        )
    };
}

View File

@ -9,10 +9,9 @@ use std::{
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tower::Service; use tower::Service;
use crate::{ use crate::{completion::CompletionRequest, vector_store::VectorStoreIndex};
completion::CompletionRequest,
vector_store::{VectorStoreError, VectorStoreIndex}, use super::ServiceError;
};
pub struct RagService<V, T> { pub struct RagService<V, T> {
vector_index: Arc<V>, vector_index: Arc<V>,
@ -39,7 +38,7 @@ where
T: Serialize + for<'a> Deserialize<'a> + Send, T: Serialize + for<'a> Deserialize<'a> + Send,
{ {
type Response = RagResult<T>; type Response = RagResult<T>;
type Error = VectorStoreError; type Error = ServiceError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>; type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@ -49,11 +48,16 @@ where
fn call(&mut self, req: CompletionRequest) -> Self::Future { fn call(&mut self, req: CompletionRequest) -> Self::Future {
let vector_index = self.vector_index.clone(); let vector_index = self.vector_index.clone();
let num_results = self.num_results; let num_results = self.num_results;
let Some(prompt) = req.prompt.rag_text() else {
todo!("Handle error properly");
};
Box::pin(async move { vector_index.top_n(&prompt, num_results).await }) Box::pin(async move {
let Some(prompt) = req.prompt.rag_text() else {
return Err(ServiceError::required_option_not_exists("rag_text"));
};
let res = vector_index.top_n(&prompt, num_results).await?;
Ok(res)
})
} }
} }

View File

@ -1,7 +1,6 @@
use crate::{ use crate::{
completion::{CompletionRequest, CompletionResponse}, completion::{CompletionRequest, CompletionResponse},
message::{AssistantContent, Message, ToolResultContent, UserContent}, message::{AssistantContent, Message, ToolResultContent},
providers::{self},
tool::{ToolSet, ToolSetError}, tool::{ToolSet, ToolSetError},
OneOrMany, OneOrMany,
}; };
@ -83,49 +82,3 @@ where
}) })
} }
} }
pub struct ToolService {
tools: Arc<ToolSet>,
}
type OpenAIResponse = CompletionResponse<providers::openai::CompletionResponse>;
impl Service<OpenAIResponse> for ToolService {
type Response = Message;
type Error = ToolSetError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(
&mut self,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: OpenAIResponse) -> Self::Future {
let tools = self.tools.clone();
Box::pin(async move {
let crate::message::AssistantContent::ToolCall(tool_call) = req.choice.first() else {
unimplemented!("handle error");
};
let Ok(res) = tools
.call(
&tool_call.function.name,
tool_call.function.arguments.to_string(),
)
.await
else {
todo!("Implement proper error handling");
};
Ok(Message::User {
content: OneOrMany::one(UserContent::tool_result(
tool_call.id,
OneOrMany::one(ToolResultContent::text(res)),
)),
})
})
}
}