fix: satisfy ci

This commit is contained in:
Joshua Mo 2025-04-09 23:21:43 +01:00
parent 0ba10f0e3e
commit bf652f47e8
5 changed files with 27 additions and 6 deletions

View File

@ -1,11 +1,14 @@
use rig::{
completion::CompletionRequestBuilder,
completion::{CompletionRequestBuilder, ToolDefinition},
middlewares::{
completion::{CompletionLayer, CompletionService},
tools::ToolLayer,
},
providers::openai::Client,
tool::{Tool, ToolSet},
};
use serde::{Deserialize, Serialize};
use tower::Service;
use tower::ServiceBuilder;
#[tokio::main]
@ -13,11 +16,11 @@ async fn main() {
let client = Client::from_env();
let model = client.completion_model("gpt-4o");
let comp_layer = CompletionLayer::builder(model).build();
let tool_layer = ToolLayer::new(vec![Add]);
let service = CompletionService::new(model);
let comp_layer = CompletionLayer::builder(model.clone()).build();
let tool_layer = ToolLayer::new(ToolSet::from_tools(vec![Add]));
let service = CompletionService::new(model.clone());
let service = ServiceBuilder::new()
let mut service = ServiceBuilder::new()
.layer(comp_layer)
.layer(tool_layer)
.service(service);
@ -32,6 +35,16 @@ async fn main() {
#[derive(Deserialize, Serialize)]
struct Add;
#[derive(Debug, thiserror::Error)]
#[error("Math error")]
struct MathError;
#[derive(Deserialize)]
struct OperationArgs {
x: i32,
y: i32,
}
impl Tool for Add {
const NAME: &'static str = "add";
@ -40,7 +53,7 @@ impl Tool for Add {
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
serde_json::from_value(json!({
serde_json::from_value(serde_json::json!({
"name": "add",
"description": "Add x and y together",
"parameters": {

View File

@ -15,6 +15,7 @@ use crate::{
OneOrMany,
};
#[derive(Clone)]
pub struct CompletionLayer<M> {
model: M,
preamble: Option<String>,
@ -155,6 +156,7 @@ where
}
}
#[derive(Clone)]
pub struct CompletionLayerService<M, S> {
inner: S,
model: M,

View File

@ -7,6 +7,7 @@ use crate::{
message::{AssistantContent, Text},
};
#[derive(Clone)]
pub struct AwaitApprovalLayer<T> {
integration: T,
}
@ -43,6 +44,7 @@ where
}
}
#[derive(Clone)]
pub struct AwaitApprovalLayerService<S, T> {
inner: S,
integration: T,

View File

@ -9,6 +9,7 @@ use crate::{
use super::ServiceError;
#[derive(Clone)]
pub struct ExtractorLayer<T> {
_t: PhantomData<T>,
}
@ -41,6 +42,7 @@ where
}
}
#[derive(Clone)]
pub struct ExtractorLayerService<S, T> {
inner: S,
_t: PhantomData<T>,

View File

@ -8,6 +8,7 @@ use std::{future::Future, pin::Pin, sync::Arc, task::Poll};
use tower::{Layer, Service};
#[derive(Clone)]
pub struct ToolLayer {
tools: Arc<ToolSet>,
}
@ -31,6 +32,7 @@ impl<S> Layer<S> for ToolLayer {
}
}
#[derive(Clone)]
pub struct ToolLayerService<S> {
inner: S,
tools: Arc<ToolSet>,