mirror of https://github.com/0xplaygrounds/rig
Merge a42ea43763
into 88edab96b5
This commit is contained in:
commit
1bc59da156
|
@ -8612,6 +8612,7 @@ dependencies = [
|
|||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
"tokio-test",
|
||||
"tower 0.5.2",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"worker",
|
||||
|
|
|
@ -9,6 +9,7 @@ members = [
|
|||
"rig-qdrant",
|
||||
"rig-core/rig-core-derive",
|
||||
"rig-sqlite",
|
||||
"rig-eternalai", "rig-fastembed",
|
||||
"rig-eternalai",
|
||||
"rig-fastembed",
|
||||
"rig-surrealdb",
|
||||
]
|
||||
|
|
|
@ -39,6 +39,7 @@ bytes = "1.9.0"
|
|||
async-stream = "0.3.6"
|
||||
mime_guess = { version = "2.0.5" }
|
||||
base64 = { version = "0.22.1" }
|
||||
tower = "0.5.2"
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
use rig::{
|
||||
completion::{CompletionRequestBuilder, ToolDefinition},
|
||||
middlewares::{
|
||||
completion::{CompletionLayer, CompletionService},
|
||||
tools::ToolLayer,
|
||||
},
|
||||
providers::openai::Client,
|
||||
tool::{Tool, ToolSet},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tower::Service;
|
||||
use tower::ServiceBuilder;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let client = Client::from_env();
|
||||
let model = client.completion_model("gpt-4o");
|
||||
|
||||
let comp_layer = CompletionLayer::builder(model.clone()).build();
|
||||
let tool_layer = ToolLayer::new(ToolSet::from_tools(vec![Add]));
|
||||
let service = CompletionService::new(model.clone());
|
||||
|
||||
let mut service = ServiceBuilder::new()
|
||||
.layer(comp_layer)
|
||||
.layer(tool_layer)
|
||||
.service(service);
|
||||
|
||||
let comp_request = CompletionRequestBuilder::new(model, "Please calculate 5+5 for me").build();
|
||||
|
||||
let res = service.call(comp_request).await.unwrap();
|
||||
|
||||
println!("{res:?}");
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
struct Add;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("Math error")]
|
||||
struct MathError;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OperationArgs {
|
||||
x: i32,
|
||||
y: i32,
|
||||
}
|
||||
|
||||
impl Tool for Add {
|
||||
const NAME: &'static str = "add";
|
||||
|
||||
type Error = MathError;
|
||||
type Args = OperationArgs;
|
||||
type Output = i32;
|
||||
|
||||
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||
serde_json::from_value(serde_json::json!({
|
||||
"name": "add",
|
||||
"description": "Add x and y together",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {
|
||||
"type": "number",
|
||||
"description": "The first number to add"
|
||||
},
|
||||
"y": {
|
||||
"type": "number",
|
||||
"description": "The second number to add"
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
.expect("Tool Definition")
|
||||
}
|
||||
|
||||
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||
let result = args.x + args.y;
|
||||
Ok(result)
|
||||
}
|
||||
}
|
|
@ -235,6 +235,7 @@ pub trait CompletionModel: Clone + Send + Sync {
|
|||
}
|
||||
|
||||
/// Struct representing a general completion request that can be sent to a completion model provider.
|
||||
#[derive(Clone)]
|
||||
pub struct CompletionRequest {
|
||||
/// The prompt to be sent to the completion model provider
|
||||
pub prompt: Message,
|
||||
|
|
|
@ -91,6 +91,7 @@ pub mod extractor;
|
|||
pub mod image_generation;
|
||||
pub(crate) mod json_utils;
|
||||
pub mod loaders;
|
||||
pub mod middlewares;
|
||||
pub mod one_or_many;
|
||||
pub mod pipeline;
|
||||
pub mod providers;
|
||||
|
|
|
@ -0,0 +1,532 @@
|
|||
use std::{future::Future, pin::Pin, task::Poll};
|
||||
use tower::{Layer, Service};
|
||||
|
||||
use crate::{
|
||||
completion::{
|
||||
CompletionModel, CompletionRequest, CompletionRequestBuilder, Document, ToolDefinition,
|
||||
},
|
||||
message::Message,
|
||||
};
|
||||
|
||||
/// A Tower layer to finish building your `CompletionRequestBuilder`.
|
||||
/// Intended to be used with [`CompletionRequestBuilderService`].
|
||||
///
|
||||
/// See [`CompletionRequestBuilderService`] for usage.
|
||||
pub struct FinishBuilding;
|
||||
|
||||
impl<S> Layer<S> for FinishBuilding {
|
||||
type Service = FinishBuildingService<S>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
FinishBuildingService { inner }
|
||||
}
|
||||
}
|
||||
|
||||
/// A Tower layer to finish building your `CompletionRequestBuilder`.
|
||||
/// Not intended to be used directly. Use [`FinishBuilding`] instead.
|
||||
///
|
||||
/// See [`CompletionRequestBuilderService`] for usage.
|
||||
pub struct FinishBuildingService<S> {
|
||||
inner: S,
|
||||
}
|
||||
|
||||
impl<M, Msg, S> Service<(M, Msg)> for FinishBuildingService<S>
|
||||
where
|
||||
M: CompletionModel + Send + 'static,
|
||||
Msg: Into<Message> + Send + 'static,
|
||||
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
|
||||
S::Future: Send + 'static,
|
||||
{
|
||||
type Response = CompletionRequest;
|
||||
type Error = ();
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: (M, Msg)) -> Self::Future {
|
||||
let mut inner = self.inner.clone();
|
||||
Box::pin(async move {
|
||||
let Ok(res) = inner.call(req).await else {
|
||||
todo!("Handle error properly");
|
||||
};
|
||||
|
||||
Ok(res.build())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A Tower layer to add documents to your `CompletionRequestBuilder`.
|
||||
/// Intended to be used with [`CompletionRequestBuilderService`].
|
||||
///
|
||||
/// See [`CompletionRequestBuilderService`] for usage.
|
||||
pub struct DocumentsLayer {
|
||||
documents: Vec<Document>,
|
||||
}
|
||||
|
||||
impl DocumentsLayer {
|
||||
pub fn new(documents: Vec<Document>) -> Self {
|
||||
Self { documents }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Layer<S> for DocumentsLayer {
|
||||
type Service = DocumentsLayerService<S>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
DocumentsLayerService {
|
||||
inner,
|
||||
documents: self.documents.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A Tower service to add documents to your `CompletionRequestBuilder`.
|
||||
/// Not intended to be used directly - use [`DocumentsLayer`] instead.
|
||||
///
|
||||
/// See [`CompletionRequestBuilderService`] for usage.
|
||||
pub struct DocumentsLayerService<S> {
|
||||
inner: S,
|
||||
documents: Vec<Document>,
|
||||
}
|
||||
|
||||
impl<M, Msg, S> Service<(M, Msg)> for DocumentsLayerService<S>
|
||||
where
|
||||
M: CompletionModel + Send + 'static,
|
||||
Msg: Into<Message> + Send + 'static,
|
||||
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
|
||||
S::Future: Send + 'static,
|
||||
{
|
||||
type Response = CompletionRequestBuilder<M>;
|
||||
type Error = ();
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: (M, Msg)) -> Self::Future {
|
||||
let mut inner = self.inner.clone();
|
||||
let documents = self.documents.clone();
|
||||
Box::pin(async move {
|
||||
let Ok(res) = inner.call(req).await else {
|
||||
todo!("Handle error properly");
|
||||
};
|
||||
|
||||
Ok(res.documents(documents))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A Tower layer to add a temperature value to your `CompletionRequestBuilder`.
|
||||
/// Intended to be used with [`CompletionRequestBuilderService`].
|
||||
///
|
||||
/// See [`CompletionRequestBuilderService`] for usage.
|
||||
pub struct TemperatureLayer {
|
||||
temperature: f64,
|
||||
}
|
||||
|
||||
impl TemperatureLayer {
|
||||
pub fn new(temperature: f64) -> Self {
|
||||
Self { temperature }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Layer<S> for TemperatureLayer {
|
||||
type Service = TemperatureLayerService<S>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
TemperatureLayerService {
|
||||
inner,
|
||||
temperature: self.temperature,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A Tower service to add a temperature value to your `CompletionRequestBuilder`.
|
||||
/// Not intended to be used directly - use [`TemperatureLayer`] instead.
|
||||
///
|
||||
/// See [`CompletionRequestBuilderService`] for usage.
|
||||
pub struct TemperatureLayerService<S> {
|
||||
inner: S,
|
||||
temperature: f64,
|
||||
}
|
||||
|
||||
impl<M, Msg, S> Service<(M, Msg)> for TemperatureLayerService<S>
|
||||
where
|
||||
M: CompletionModel + Send + 'static,
|
||||
Msg: Into<Message> + Send + 'static,
|
||||
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
|
||||
S::Future: Send + 'static,
|
||||
{
|
||||
type Response = CompletionRequestBuilder<M>;
|
||||
type Error = ();
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: (M, Msg)) -> Self::Future {
|
||||
let mut inner = self.inner.clone();
|
||||
let temperature = self.temperature;
|
||||
Box::pin(async move {
|
||||
let Ok(res) = inner.call(req).await else {
|
||||
todo!("Handle error properly");
|
||||
};
|
||||
|
||||
Ok(res.temperature(temperature))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A Tower service to add tools to your `CompletionRequestBuilder`.
|
||||
/// Intended to be used with [`CompletionRequestBuilderService`].
|
||||
///
|
||||
/// See [`CompletionRequestBuilderService`] for usage.
|
||||
pub struct ToolsLayer {
|
||||
tools: Vec<ToolDefinition>,
|
||||
}
|
||||
|
||||
impl ToolsLayer {
|
||||
pub fn new(tools: Vec<ToolDefinition>) -> Self {
|
||||
Self { tools }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Layer<S> for ToolsLayer {
|
||||
type Service = ToolsLayerService<S>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
ToolsLayerService {
|
||||
inner,
|
||||
tools: self.tools.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A Tower service to add tools to your `CompletionRequestBuilder`.
|
||||
/// Not intended to be used directly - use [`ToolsLayer`] instead.
|
||||
///
|
||||
/// See [`CompletionRequestBuilderService`] for usage.
|
||||
pub struct ToolsLayerService<S> {
|
||||
inner: S,
|
||||
tools: Vec<ToolDefinition>,
|
||||
}
|
||||
|
||||
impl<M, Msg, S> Service<(M, Msg)> for ToolsLayerService<S>
|
||||
where
|
||||
M: CompletionModel + Send + 'static,
|
||||
Msg: Into<Message> + Send + 'static,
|
||||
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
|
||||
S::Future: Send + 'static,
|
||||
{
|
||||
type Response = CompletionRequestBuilder<M>;
|
||||
type Error = ();
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: (M, Msg)) -> Self::Future {
|
||||
let mut inner = self.inner.clone();
|
||||
let tools = self.tools.clone();
|
||||
Box::pin(async move {
|
||||
let Ok(res) = inner.call(req).await else {
|
||||
todo!("Handle error properly");
|
||||
};
|
||||
|
||||
Ok(res.tools(tools))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A Tower layer to add a preamble ("system message") to your `CompletionRequestBuilder`.
|
||||
/// Intended to be used with [`CompletionRequestBuilderService`].
|
||||
///
|
||||
/// See [`CompletionRequestBuilderService`] for usage.
|
||||
pub struct PreambleLayer {
|
||||
preamble: String,
|
||||
}
|
||||
|
||||
impl PreambleLayer {
|
||||
pub fn new(preamble: String) -> Self {
|
||||
Self { preamble }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Layer<S> for PreambleLayer {
|
||||
type Service = PreambleLayerService<S>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
PreambleLayerService {
|
||||
inner,
|
||||
preamble: self.preamble.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A Tower service to add a preamble ("system message") to your `CompletionRequestBuilder`.
|
||||
/// Not intended to be used directly - use [`PreambleLayer`] instead.
|
||||
///
|
||||
/// See [`CompletionRequestBuilderService`] for usage.
|
||||
pub struct PreambleLayerService<S> {
|
||||
inner: S,
|
||||
preamble: String,
|
||||
}
|
||||
|
||||
impl<M, Msg, S> Service<(M, Msg)> for PreambleLayerService<S>
|
||||
where
|
||||
M: CompletionModel + Send + 'static,
|
||||
Msg: Into<Message> + Send + 'static,
|
||||
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
|
||||
S::Future: Send + 'static,
|
||||
{
|
||||
type Response = CompletionRequestBuilder<M>;
|
||||
type Error = ();
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: (M, Msg)) -> Self::Future {
|
||||
let mut inner = self.inner.clone();
|
||||
let preamble = self.preamble.clone();
|
||||
Box::pin(async move {
|
||||
let Ok(res) = inner.call(req).await else {
|
||||
todo!("Handle error properly");
|
||||
};
|
||||
|
||||
Ok(res.preamble(preamble))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A Tower layer to add additional parameters to your `CompletionRequestBuilder` (which are not already covered by the other parameters).
|
||||
/// Intended to be used with [`CompletionRequestBuilderService`].
|
||||
///
|
||||
/// See [`CompletionRequestBuilderService`] for usage.
|
||||
pub struct AdditionalParamsLayer {
|
||||
additional_params: serde_json::Value,
|
||||
}
|
||||
|
||||
impl AdditionalParamsLayer {
|
||||
pub fn new(additional_params: serde_json::Value) -> Self {
|
||||
Self { additional_params }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Layer<S> for AdditionalParamsLayer {
|
||||
type Service = AdditionalParamsLayerService<S>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
AdditionalParamsLayerService {
|
||||
inner,
|
||||
additional_params: self.additional_params.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A Tower layer to add additional parameters to your `CompletionRequestBuilder` (which are not already covered by the other parameters).
|
||||
/// Not intended to be used directly - use [`AdditionalParamsLayer`] instead.
|
||||
///
|
||||
/// See [`CompletionRequestBuilderService`] for usage.
|
||||
pub struct AdditionalParamsLayerService<S> {
|
||||
inner: S,
|
||||
additional_params: serde_json::Value,
|
||||
}
|
||||
|
||||
impl<M, Msg, S> Service<(M, Msg)> for AdditionalParamsLayerService<S>
|
||||
where
|
||||
M: CompletionModel + Send + 'static,
|
||||
Msg: Into<Message> + Send + 'static,
|
||||
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
|
||||
S::Future: Send + 'static,
|
||||
{
|
||||
type Response = CompletionRequestBuilder<M>;
|
||||
type Error = ();
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: (M, Msg)) -> Self::Future {
|
||||
let mut inner = self.inner.clone();
|
||||
let additional_params = self.additional_params.clone();
|
||||
Box::pin(async move {
|
||||
let Ok(res) = inner.call(req).await else {
|
||||
todo!("Handle error properly");
|
||||
};
|
||||
|
||||
Ok(res.additional_params(additional_params))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A Tower layer to add a maximum tokens parameter to your `CompletionRequestBuilder`.
|
||||
/// Intended to be used with [`CompletionRequestBuilderService`].
|
||||
///
|
||||
/// See [`CompletionRequestBuilderService`] for usage.
|
||||
pub struct MaxTokensLayer {
|
||||
max_tokens: u64,
|
||||
}
|
||||
|
||||
impl MaxTokensLayer {
|
||||
pub fn new(max_tokens: u64) -> Self {
|
||||
Self { max_tokens }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Layer<S> for MaxTokensLayer {
|
||||
type Service = MaxTokensLayerService<S>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
MaxTokensLayerService {
|
||||
inner,
|
||||
max_tokens: self.max_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A Tower service to add a maximum tokens parameter to your `CompletionRequestBuilder`.
|
||||
/// Not intended to be used directly - use [`MaxTokensLayer`] instead.
|
||||
///
|
||||
/// See [`CompletionRequestBuilderService`] for usage.
|
||||
pub struct MaxTokensLayerService<S> {
|
||||
inner: S,
|
||||
max_tokens: u64,
|
||||
}
|
||||
|
||||
impl<M, Msg, S> Service<(M, Msg)> for MaxTokensLayerService<S>
|
||||
where
|
||||
M: CompletionModel + Send + 'static,
|
||||
Msg: Into<Message> + Send + 'static,
|
||||
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
|
||||
S::Future: Send + 'static,
|
||||
{
|
||||
type Response = CompletionRequestBuilder<M>;
|
||||
type Error = ();
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: (M, Msg)) -> Self::Future {
|
||||
let mut inner = self.inner.clone();
|
||||
let max_tokens = self.max_tokens;
|
||||
Box::pin(async move {
|
||||
let Ok(res) = inner.call(req).await else {
|
||||
todo!("Handle error properly");
|
||||
};
|
||||
|
||||
Ok(res.max_tokens(max_tokens))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A Tower layer to add a chat history to your `CompletionRequestBuilder`.
|
||||
/// Intended to be used with [`CompletionRequestBuilderService`].
|
||||
///
|
||||
/// See [`CompletionRequestBuilderService`] for usage.
|
||||
pub struct ChatHistoryLayer {
|
||||
chat_history: Vec<Message>,
|
||||
}
|
||||
|
||||
impl ChatHistoryLayer {
|
||||
pub fn new(chat_history: Vec<Message>) -> Self {
|
||||
Self { chat_history }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Layer<S> for ChatHistoryLayer {
|
||||
type Service = ChatHistoryLayerService<S>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
ChatHistoryLayerService {
|
||||
inner,
|
||||
chat_history: self.chat_history.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A Tower service to add a chat history to your `CompletionRequestBuilder`.
|
||||
/// Not intended to be used directly - use [`ChatHistoryLayer`] instead.
|
||||
///
|
||||
/// See [`CompletionRequestBuilderService`] for usage.
|
||||
pub struct ChatHistoryLayerService<S> {
|
||||
inner: S,
|
||||
chat_history: Vec<Message>,
|
||||
}
|
||||
|
||||
impl<M, Msg, S> Service<(M, Msg)> for ChatHistoryLayerService<S>
|
||||
where
|
||||
M: CompletionModel + Send + 'static,
|
||||
Msg: Into<Message> + Send + 'static,
|
||||
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
|
||||
S::Future: Send + 'static,
|
||||
{
|
||||
type Response = CompletionRequestBuilder<M>;
|
||||
type Error = ();
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: (M, Msg)) -> Self::Future {
|
||||
let mut inner = self.inner.clone();
|
||||
let chat_history = self.chat_history.clone();
|
||||
Box::pin(async move {
|
||||
let Ok(res) = inner.call(req).await else {
|
||||
todo!("Handle error properly");
|
||||
};
|
||||
|
||||
Ok(res.messages(chat_history))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A `tower` service for building a completion request.
|
||||
/// Note that the last layer to be applied goes first and the core service is placed last (so it gets executedd from bottom to top).
|
||||
///
|
||||
/// Usage:
|
||||
/// ```rust
|
||||
/// let agent = rig::providers::openai::Client::from_env();
|
||||
/// let model = agent.completion_model("gpt-4o");
|
||||
///
|
||||
/// let service = tower::ServiceBuilder::new()
|
||||
/// .layer(FinishBuilding)
|
||||
/// .layer(TemperatureLayer::new(0.0))
|
||||
/// .layer(PreambleLayer::new("You are a helpful assistant"))
|
||||
/// .service(CompletionRequestBuilderService);
|
||||
///
|
||||
/// let request = service.call((model, "Hello world!".to_string())).await.unwrap();
|
||||
///
|
||||
/// let res = request.send().await.unwrap();
|
||||
///
|
||||
/// println!("{res:?}");
|
||||
/// ```
|
||||
pub struct CompletionRequestBuilderService;
|
||||
|
||||
impl<M, Msg> Service<(M, Msg)> for CompletionRequestBuilderService
|
||||
where
|
||||
M: CompletionModel + Send + 'static,
|
||||
Msg: Into<Message> + Send + 'static,
|
||||
{
|
||||
type Response = CompletionRequestBuilder<M>;
|
||||
type Error = ();
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, (model, prompt): (M, Msg)) -> Self::Future {
|
||||
Box::pin(async move { Ok(CompletionRequestBuilder::new(model, prompt)) })
|
||||
}
|
||||
}
|
|
@ -0,0 +1,262 @@
|
|||
use std::{
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use tower::{Layer, Service};
|
||||
|
||||
use crate::{
|
||||
completion::{
|
||||
CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder,
|
||||
CompletionResponse, Document, ToolDefinition,
|
||||
},
|
||||
message::{Message, ToolResultContent, UserContent},
|
||||
OneOrMany,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CompletionLayer<M> {
|
||||
model: M,
|
||||
preamble: Option<String>,
|
||||
documents: Vec<Document>,
|
||||
tools: Vec<ToolDefinition>,
|
||||
temperature: Option<f64>,
|
||||
max_tokens: Option<u64>,
|
||||
additional_params: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl<M> CompletionLayer<M>
|
||||
where
|
||||
M: CompletionModel,
|
||||
{
|
||||
pub fn builder(model: M) -> CompletionLayerBuilder<M> {
|
||||
CompletionLayerBuilder::new(model)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct CompletionLayerBuilder<M> {
|
||||
model: M,
|
||||
preamble: Option<String>,
|
||||
documents: Vec<Document>,
|
||||
tools: Vec<ToolDefinition>,
|
||||
temperature: Option<f64>,
|
||||
max_tokens: Option<u64>,
|
||||
additional_params: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl<M> CompletionLayerBuilder<M>
|
||||
where
|
||||
M: CompletionModel,
|
||||
{
|
||||
pub fn new(model: M) -> Self {
|
||||
Self {
|
||||
model,
|
||||
preamble: None,
|
||||
documents: vec![],
|
||||
tools: vec![],
|
||||
temperature: None,
|
||||
max_tokens: None,
|
||||
additional_params: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn preamble(mut self, preamble: String) -> Self {
|
||||
self.preamble = Some(preamble);
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn preamble_opt(mut self, preamble: Option<String>) -> Self {
|
||||
self.preamble = preamble;
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn documents(mut self, documents: Vec<Document>) -> Self {
|
||||
self.documents = documents;
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn tools(mut self, tools: Vec<ToolDefinition>) -> Self {
|
||||
self.tools = tools;
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn temperature(mut self, temperature: f64) -> Self {
|
||||
self.temperature = Some(temperature);
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
|
||||
self.temperature = temperature;
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn max_tokens(mut self, max_tokens: u64) -> Self {
|
||||
self.max_tokens = Some(max_tokens);
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
|
||||
self.max_tokens = max_tokens;
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn additional_params(mut self, params: serde_json::Value) -> Self {
|
||||
self.additional_params = Some(params);
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn additional_params_opt(mut self, params: Option<serde_json::Value>) -> Self {
|
||||
self.additional_params = params;
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> CompletionLayer<M> {
|
||||
CompletionLayer {
|
||||
model: self.model,
|
||||
preamble: self.preamble,
|
||||
|
||||
documents: self.documents,
|
||||
tools: self.tools,
|
||||
temperature: self.temperature,
|
||||
max_tokens: self.max_tokens,
|
||||
additional_params: self.additional_params,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<M, S> Layer<S> for CompletionLayer<M>
|
||||
where
|
||||
M: CompletionModel,
|
||||
{
|
||||
type Service = CompletionLayerService<M, S>;
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
CompletionLayerService {
|
||||
inner,
|
||||
model: self.model.clone(),
|
||||
preamble: self.preamble.clone(),
|
||||
|
||||
documents: self.documents.clone(),
|
||||
tools: self.tools.clone(),
|
||||
temperature: self.temperature,
|
||||
max_tokens: self.max_tokens,
|
||||
additional_params: self.additional_params.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CompletionLayerService<M, S> {
|
||||
inner: S,
|
||||
model: M,
|
||||
preamble: Option<String>,
|
||||
documents: Vec<Document>,
|
||||
tools: Vec<ToolDefinition>,
|
||||
temperature: Option<f64>,
|
||||
max_tokens: Option<u64>,
|
||||
additional_params: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl<M, S> Service<CompletionRequest> for CompletionLayerService<M, S>
|
||||
where
|
||||
M: CompletionModel + 'static,
|
||||
S: Service<CompletionRequest, Response = (Vec<Message>, String, ToolResultContent)>
|
||||
+ Clone
|
||||
+ Send
|
||||
+ 'static,
|
||||
S::Future: Send,
|
||||
{
|
||||
type Response = CompletionResponse<M::Response>;
|
||||
type Error = CompletionError;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: CompletionRequest) -> Self::Future {
|
||||
let mut inner = self.inner.clone();
|
||||
let model = self.model.clone();
|
||||
let preamble = self.preamble.clone();
|
||||
let documents = self.documents.clone();
|
||||
let temperature = self.temperature;
|
||||
let tools = self.tools.clone();
|
||||
let max_tokens = self.max_tokens;
|
||||
let additional_params = self.additional_params.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
let Ok((messages, id, tool_content)) = inner.call(req).await else {
|
||||
todo!("Handle error properly");
|
||||
};
|
||||
|
||||
let tool_result_message = Message::User {
|
||||
content: OneOrMany::one(UserContent::tool_result(id, OneOrMany::one(tool_content))),
|
||||
};
|
||||
|
||||
let mut req = CompletionRequestBuilder::new(model.clone(), tool_result_message)
|
||||
.documents(documents.clone())
|
||||
.tools(tools.clone())
|
||||
.messages(messages)
|
||||
.temperature_opt(temperature)
|
||||
.max_tokens_opt(max_tokens)
|
||||
.additional_params_opt(additional_params.clone());
|
||||
|
||||
if let Some(preamble) = preamble.clone() {
|
||||
req = req.preamble(preamble);
|
||||
}
|
||||
|
||||
let req = req.build();
|
||||
|
||||
model.completion(req).await
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A completion model as a Tower service.
|
||||
///
|
||||
/// This allows you to use an LLM model (or client) essentially anywhere you'd use a regular Tower layer, like in an Axum web service.
|
||||
#[derive(Clone)]
|
||||
pub struct CompletionService<M> {
|
||||
/// The model itself.
|
||||
model: M,
|
||||
}
|
||||
|
||||
impl<M> CompletionService<M>
|
||||
where
|
||||
M: CompletionModel,
|
||||
{
|
||||
pub fn new(model: M) -> Self {
|
||||
Self { model }
|
||||
}
|
||||
}
|
||||
|
||||
impl<M> Service<CompletionRequest> for CompletionService<M>
|
||||
where
|
||||
M: CompletionModel + 'static,
|
||||
{
|
||||
type Response = CompletionResponse<M::Response>;
|
||||
type Error = CompletionError;
|
||||
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: CompletionRequest) -> Self::Future {
|
||||
let model = self.model.clone();
|
||||
|
||||
Box::pin(async move { model.completion(req).await })
|
||||
}
|
||||
}
|
|
@ -0,0 +1,247 @@
|
|||
use std::{fmt::Debug, future::Future, marker::PhantomData, pin::Pin, task::Poll};
|
||||
|
||||
use tower::{Layer, Service};
|
||||
|
||||
use crate::{
|
||||
completion::{CompletionRequest, CompletionResponse},
|
||||
message::{AssistantContent, Text},
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AwaitApprovalLayer<T> {
|
||||
integration: T,
|
||||
}
|
||||
|
||||
impl<T> AwaitApprovalLayer<T>
|
||||
where
|
||||
T: HumanInTheLoop + Clone,
|
||||
{
|
||||
pub fn new(integration: T) -> Self {
|
||||
Self { integration }
|
||||
}
|
||||
|
||||
pub fn with_predicate<R, D>(self, predicate: R) -> AwaitApprovalLayerWithPredicate<D, R, T>
|
||||
where
|
||||
D: Debug,
|
||||
R: Fn() -> Pin<Box<dyn Future<Output = bool> + Send>> + Clone + Send + 'static,
|
||||
{
|
||||
AwaitApprovalLayerWithPredicate {
|
||||
integration: self.integration,
|
||||
predicate,
|
||||
_t: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, T> Layer<S> for AwaitApprovalLayer<T>
|
||||
where
|
||||
T: HumanInTheLoop + Clone,
|
||||
{
|
||||
type Service = AwaitApprovalLayerService<S, T>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
AwaitApprovalLayerService::new(inner, self.integration.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AwaitApprovalLayerService<S, T> {
|
||||
inner: S,
|
||||
integration: T,
|
||||
}
|
||||
|
||||
impl<S, T> AwaitApprovalLayerService<S, T>
|
||||
where
|
||||
T: HumanInTheLoop,
|
||||
{
|
||||
pub fn new(inner: S, integration: T) -> Self {
|
||||
Self { inner, integration }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, T, Response> Service<CompletionRequest> for AwaitApprovalLayerService<S, T>
|
||||
where
|
||||
S: Service<CompletionRequest, Response = CompletionResponse<Response>> + Clone + Send + 'static,
|
||||
S::Future: Send,
|
||||
Response: Clone + 'static + Send,
|
||||
T: HumanInTheLoop + Clone + Send + 'static,
|
||||
{
|
||||
type Response = CompletionResponse<Response>;
|
||||
type Error = bool;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(
|
||||
&mut self,
|
||||
_cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: CompletionRequest) -> Self::Future {
|
||||
let mut inner = self.inner.clone();
|
||||
let await_approval_loop = self.integration.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
let Ok(res) = inner.call(req).await else {
|
||||
todo!("Handle error properly");
|
||||
};
|
||||
|
||||
let AssistantContent::Text(Text { text }) = res.choice.first() else {
|
||||
todo!("Handle error properly");
|
||||
};
|
||||
|
||||
if await_approval_loop.send_message(&text).await.is_err() {
|
||||
todo!("Handle error properly");
|
||||
}
|
||||
|
||||
let Ok(bool) = await_approval_loop.await_approval().await else {
|
||||
todo!("Handle error properly");
|
||||
};
|
||||
|
||||
if bool {
|
||||
Ok(res)
|
||||
} else {
|
||||
todo!("Handle error properly - we should abort the pipeline here if the user wants to abort");
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AwaitApprovalLayerWithPredicate<P, R, T> {
|
||||
integration: T,
|
||||
predicate: R,
|
||||
_t: PhantomData<P>,
|
||||
}
|
||||
|
||||
impl<D, R, S, T> Layer<S> for AwaitApprovalLayerWithPredicate<D, R, T>
|
||||
where
|
||||
T: HumanInTheLoop + Clone,
|
||||
D: Debug,
|
||||
R: Fn(&D) -> Pin<Box<dyn Future<Output = bool> + Send>> + Clone + Send + 'static,
|
||||
{
|
||||
type Service = AwaitApprovalLayerServiceWithPredicate<D, R, S, T>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
let predicate = self.predicate.clone();
|
||||
AwaitApprovalLayerServiceWithPredicate::new(inner, self.integration.clone(), predicate)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AwaitApprovalLayerServiceWithPredicate<D, R, S, T> {
|
||||
inner: S,
|
||||
integration: T,
|
||||
predicate: R,
|
||||
_t: PhantomData<D>,
|
||||
}
|
||||
|
||||
impl<D, R, S, T> AwaitApprovalLayerServiceWithPredicate<D, R, S, T>
|
||||
where
|
||||
T: HumanInTheLoop,
|
||||
R: Fn(&D) -> Pin<Box<dyn Future<Output = bool> + Send>> + Clone + Send + 'static,
|
||||
D: Debug,
|
||||
{
|
||||
pub fn new(inner: S, integration: T, predicate: R) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
integration,
|
||||
predicate,
|
||||
_t: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, S, T, Response, R> Service<CompletionRequest>
|
||||
for AwaitApprovalLayerServiceWithPredicate<D, R, S, T>
|
||||
where
|
||||
R: Fn(&CompletionResponse<Response>) -> Pin<Box<dyn Future<Output = bool> + Send>>
|
||||
+ Clone
|
||||
+ Send
|
||||
+ 'static,
|
||||
S: Service<CompletionRequest, Response = CompletionResponse<Response>> + Clone + Send + 'static,
|
||||
S::Future: Send,
|
||||
Response: Clone + 'static + Send,
|
||||
T: HumanInTheLoop + Clone + Send + 'static,
|
||||
{
|
||||
type Response = CompletionResponse<Response>;
|
||||
type Error = bool;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(
|
||||
&mut self,
|
||||
_cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: CompletionRequest) -> Self::Future {
|
||||
let mut inner = self.inner.clone();
|
||||
let await_approval_loop = self.integration.clone();
|
||||
let predicate = self.predicate.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
let Ok(res) = inner.call(req).await else {
|
||||
todo!("Handle error properly");
|
||||
};
|
||||
|
||||
if predicate(&res).await {
|
||||
return Ok(res);
|
||||
}
|
||||
|
||||
let AssistantContent::Text(Text { text }) = res.choice.first() else {
|
||||
todo!("Handle error properly");
|
||||
};
|
||||
|
||||
if await_approval_loop.send_message(&text).await.is_err() {
|
||||
todo!("Handle error properly");
|
||||
}
|
||||
|
||||
let Ok(bool) = await_approval_loop.await_approval().await else {
|
||||
todo!("Handle error properly");
|
||||
};
|
||||
|
||||
if bool {
|
||||
Ok(res)
|
||||
} else {
|
||||
todo!("Handle error properly - we should abort the pipeline here if the user wants to abort");
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub trait HumanInTheLoop {
|
||||
fn send_message(
|
||||
&self,
|
||||
res: &str,
|
||||
) -> impl Future<Output = Result<(), Box<dyn std::error::Error>>> + Send;
|
||||
fn await_approval(
|
||||
&self,
|
||||
) -> impl Future<Output = Result<bool, Box<dyn std::error::Error>>> + Send;
|
||||
}
|
||||
|
||||
pub struct Stdout;
|
||||
|
||||
impl HumanInTheLoop for Stdout {
|
||||
async fn send_message(&self, res: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
print!(
|
||||
"Current result: {res}
|
||||
|
||||
Would you like to approve this step? [Y/n]"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn await_approval(&self) -> Result<bool, Box<dyn std::error::Error>> {
|
||||
let mut string = String::new();
|
||||
|
||||
loop {
|
||||
std::io::stdin().read_line(&mut string).unwrap();
|
||||
|
||||
match string.to_lowercase().trim() {
|
||||
"y" | "yes" => break Ok(true),
|
||||
"n" | "no" => break Ok(false),
|
||||
_ => println!("Please respond with 'y' or 'n'."),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,87 @@
|
|||
use serde::Deserialize;
|
||||
use std::{future::Future, marker::PhantomData, pin::Pin, task::Poll};
|
||||
use tower::{Layer, Service};
|
||||
|
||||
use crate::{
|
||||
completion::{CompletionRequest, CompletionResponse},
|
||||
message::{AssistantContent, Text},
|
||||
};
|
||||
|
||||
use super::ServiceError;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ExtractorLayer<T> {
|
||||
_t: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T> ExtractorLayer<T>
|
||||
where
|
||||
T: for<'a> Deserialize<'a>,
|
||||
{
|
||||
pub fn new() -> Self {
|
||||
Self { _t: PhantomData }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Default for ExtractorLayer<T>
|
||||
where
|
||||
T: for<'a> Deserialize<'a>,
|
||||
{
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, T> Layer<S> for ExtractorLayer<T>
|
||||
where
|
||||
T: for<'a> Deserialize<'a>,
|
||||
{
|
||||
type Service = ExtractorLayerService<S, T>;
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
ExtractorLayerService { inner, _t: self._t }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ExtractorLayerService<S, T> {
|
||||
inner: S,
|
||||
_t: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<S, F, T> Service<CompletionRequest> for ExtractorLayerService<S, T>
|
||||
where
|
||||
S: Service<CompletionRequest, Response = CompletionResponse<F>, Error = ServiceError>
|
||||
+ Clone
|
||||
+ Send
|
||||
+ 'static,
|
||||
S::Future: Send,
|
||||
F: 'static,
|
||||
T: for<'a> Deserialize<'a> + 'static,
|
||||
{
|
||||
type Response = T;
|
||||
type Error = ServiceError;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(
|
||||
&mut self,
|
||||
_cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: CompletionRequest) -> Self::Future {
|
||||
let mut inner = self.inner.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
let res = inner.call(req).await?;
|
||||
|
||||
let AssistantContent::Text(Text { text }) = res.choice.first() else {
|
||||
todo!("Handle errors properly");
|
||||
};
|
||||
|
||||
let obj = serde_json::from_str::<T>(&text)?;
|
||||
|
||||
Ok(obj)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
use thiserror::Error;
|
||||
|
||||
use crate::{
|
||||
completion::CompletionError, extractor::ExtractionError, tool::ToolSetError,
|
||||
vector_store::VectorStoreError,
|
||||
};
|
||||
|
||||
pub mod build_completions;
|
||||
pub mod completion;
|
||||
pub mod components;
|
||||
pub mod extractor;
|
||||
pub mod parallel;
|
||||
pub mod rag;
|
||||
pub mod tools;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ServiceError {
|
||||
#[error("{0}")]
|
||||
ExtractionError(#[from] ExtractionError),
|
||||
#[error("{0}")]
|
||||
CompletionError(#[from] CompletionError),
|
||||
#[error("{0}")]
|
||||
ToolSetError(#[from] ToolSetError),
|
||||
#[error("{0}")]
|
||||
VectorStoreError(#[from] VectorStoreError),
|
||||
#[error("Value required but was null: {0}")]
|
||||
RequiredOptionNotFound(String),
|
||||
#[error("{0}")]
|
||||
Json(#[from] serde_json::Error),
|
||||
#[error("Custom error: {0}")]
|
||||
Other(#[from] Box<dyn std::error::Error + Send + Sync>),
|
||||
}
|
||||
|
||||
impl ServiceError {
|
||||
pub fn required_option_not_exists<S: Into<String>>(val: S) -> Self {
|
||||
let val: String = val.into();
|
||||
Self::RequiredOptionNotFound(val)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,88 @@
|
|||
use std::{future::Future, pin::Pin};
|
||||
|
||||
use tower::Service;
|
||||
|
||||
use crate::completion::CompletionRequest;
|
||||
|
||||
use super::ServiceError;
|
||||
|
||||
pub struct Stackable<A, B> {
|
||||
pub inner: A,
|
||||
pub outer: B,
|
||||
}
|
||||
|
||||
impl<A, B> Stackable<A, B> {
|
||||
pub fn new(inner: A, outer: B) -> Self {
|
||||
Self { inner, outer }
|
||||
}
|
||||
|
||||
pub fn take_values(self) -> (A, B) {
|
||||
(self.inner, self.outer)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ParallelService<S, T> {
|
||||
first_service: S,
|
||||
second_service: T,
|
||||
}
|
||||
|
||||
impl<S, T> ParallelService<S, T>
|
||||
where
|
||||
S: Service<CompletionRequest>,
|
||||
T: Service<CompletionRequest>,
|
||||
{
|
||||
pub fn new(first_service: S, second_service: T) -> Self {
|
||||
Self {
|
||||
first_service,
|
||||
second_service,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, T> Service<CompletionRequest> for ParallelService<S, T>
|
||||
where
|
||||
S: Service<CompletionRequest, Error = ServiceError> + Clone + Send + 'static,
|
||||
S::Future: Send,
|
||||
S::Response: Send + 'static,
|
||||
T: Service<CompletionRequest, Error = ServiceError> + Clone + Send + 'static,
|
||||
T::Future: Send,
|
||||
T::Response: Send + 'static,
|
||||
{
|
||||
type Response = Stackable<S::Response, T::Response>;
|
||||
type Error = ServiceError;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(
|
||||
&mut self,
|
||||
_cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
std::task::Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: CompletionRequest) -> Self::Future {
|
||||
let mut first = self.first_service.clone();
|
||||
let mut second = self.second_service.clone();
|
||||
Box::pin(async move {
|
||||
let res1 = first.call(req.clone()).await?;
|
||||
let res2 = second.call(req.clone()).await?;
|
||||
|
||||
let stackable = Stackable::new(res1, res2);
|
||||
|
||||
Ok(stackable)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! parallel_service {
|
||||
($service1:tt, $service2:tt) => {
|
||||
$crate::pipeline::parallel::ParallelService::new($service1, $service2)
|
||||
};
|
||||
($op1:tt $(, $ops:tt)*) => {
|
||||
$crate::pipeline::parallel::ParallelService::new(
|
||||
$service1,
|
||||
$crate::parallel_op!($($ops),*)
|
||||
)
|
||||
};
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
use std::{
|
||||
future::Future,
|
||||
marker::PhantomData,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tower::Service;
|
||||
|
||||
use crate::{completion::CompletionRequest, vector_store::VectorStoreIndex};
|
||||
|
||||
use super::ServiceError;
|
||||
|
||||
pub struct RagService<V, T> {
|
||||
vector_index: Arc<V>,
|
||||
num_results: usize,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<V, T> RagService<V, T>
|
||||
where
|
||||
V: VectorStoreIndex,
|
||||
{
|
||||
pub fn new(vector_index: V, num_results: usize) -> Self {
|
||||
Self {
|
||||
vector_index: Arc::new(vector_index),
|
||||
num_results,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<V, T> Service<CompletionRequest> for RagService<V, T>
|
||||
where
|
||||
V: VectorStoreIndex + 'static,
|
||||
T: Serialize + for<'a> Deserialize<'a> + Send,
|
||||
{
|
||||
type Response = RagResult<T>;
|
||||
type Error = ServiceError;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: CompletionRequest) -> Self::Future {
|
||||
let vector_index = self.vector_index.clone();
|
||||
let num_results = self.num_results;
|
||||
|
||||
Box::pin(async move {
|
||||
let Some(prompt) = req.prompt.rag_text() else {
|
||||
return Err(ServiceError::required_option_not_exists("rag_text"));
|
||||
};
|
||||
|
||||
let res = vector_index.top_n(&prompt, num_results).await?;
|
||||
|
||||
Ok(res)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub type RagResult<T> = Vec<(f64, String, T)>;
|
|
@ -0,0 +1,86 @@
|
|||
use crate::{
|
||||
completion::{CompletionRequest, CompletionResponse},
|
||||
message::{AssistantContent, Message, ToolResultContent},
|
||||
tool::{ToolSet, ToolSetError},
|
||||
OneOrMany,
|
||||
};
|
||||
use std::{future::Future, pin::Pin, sync::Arc, task::Poll};
|
||||
|
||||
use tower::{Layer, Service};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ToolLayer {
|
||||
tools: Arc<ToolSet>,
|
||||
}
|
||||
|
||||
impl ToolLayer {
|
||||
pub fn new(tools: ToolSet) -> Self {
|
||||
Self {
|
||||
tools: Arc::new(tools),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Layer<S> for ToolLayer {
|
||||
type Service = ToolLayerService<S>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
ToolLayerService {
|
||||
inner,
|
||||
tools: Arc::clone(&self.tools),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ToolLayerService<S> {
|
||||
inner: S,
|
||||
tools: Arc<ToolSet>,
|
||||
}
|
||||
|
||||
impl<S, T> Service<CompletionRequest> for ToolLayerService<S>
|
||||
where
|
||||
S: Service<CompletionRequest, Response = CompletionResponse<T>> + Clone + Send + 'static,
|
||||
T: Send + 'static,
|
||||
S::Future: Send,
|
||||
{
|
||||
type Response = (Vec<Message>, String, ToolResultContent);
|
||||
type Error = ToolSetError;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: CompletionRequest) -> Self::Future {
|
||||
let mut inner = self.inner.clone();
|
||||
let tools = self.tools.clone();
|
||||
let mut messages = req.chat_history.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
let Ok(res) = inner.call(req).await else {
|
||||
todo!("Handle error properly");
|
||||
};
|
||||
|
||||
let AssistantContent::ToolCall(tool_call) = res.choice.first() else {
|
||||
todo!("Handle error properly");
|
||||
};
|
||||
|
||||
messages.push(Message::Assistant {
|
||||
content: OneOrMany::one(AssistantContent::ToolCall(tool_call.clone())),
|
||||
});
|
||||
|
||||
let Ok(res) = tools
|
||||
.call(
|
||||
&tool_call.function.name,
|
||||
tool_call.function.arguments.to_string(),
|
||||
)
|
||||
.await
|
||||
else {
|
||||
todo!("Implement proper error handling");
|
||||
};
|
||||
|
||||
Ok((messages, tool_call.id, ToolResultContent::text(res)))
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue