mirror of https://github.com/0xplaygrounds/rig
fix: satisfy ci
This commit is contained in:
parent
0ba10f0e3e
commit
bf652f47e8
|
@ -1,11 +1,14 @@
|
|||
use rig::{
|
||||
completion::CompletionRequestBuilder,
|
||||
completion::{CompletionRequestBuilder, ToolDefinition},
|
||||
middlewares::{
|
||||
completion::{CompletionLayer, CompletionService},
|
||||
tools::ToolLayer,
|
||||
},
|
||||
providers::openai::Client,
|
||||
tool::{Tool, ToolSet},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tower::Service;
|
||||
use tower::ServiceBuilder;
|
||||
|
||||
#[tokio::main]
|
||||
|
@ -13,11 +16,11 @@ async fn main() {
|
|||
let client = Client::from_env();
|
||||
let model = client.completion_model("gpt-4o");
|
||||
|
||||
let comp_layer = CompletionLayer::builder(model).build();
|
||||
let tool_layer = ToolLayer::new(vec![Add]);
|
||||
let service = CompletionService::new(model);
|
||||
let comp_layer = CompletionLayer::builder(model.clone()).build();
|
||||
let tool_layer = ToolLayer::new(ToolSet::from_tools(vec![Add]));
|
||||
let service = CompletionService::new(model.clone());
|
||||
|
||||
let service = ServiceBuilder::new()
|
||||
let mut service = ServiceBuilder::new()
|
||||
.layer(comp_layer)
|
||||
.layer(tool_layer)
|
||||
.service(service);
|
||||
|
@ -32,6 +35,16 @@ async fn main() {
|
|||
#[derive(Deserialize, Serialize)]
|
||||
struct Add;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("Math error")]
|
||||
struct MathError;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OperationArgs {
|
||||
x: i32,
|
||||
y: i32,
|
||||
}
|
||||
|
||||
impl Tool for Add {
|
||||
const NAME: &'static str = "add";
|
||||
|
||||
|
@ -40,7 +53,7 @@ impl Tool for Add {
|
|||
type Output = i32;
|
||||
|
||||
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||
serde_json::from_value(json!({
|
||||
serde_json::from_value(serde_json::json!({
|
||||
"name": "add",
|
||||
"description": "Add x and y together",
|
||||
"parameters": {
|
||||
|
|
|
@ -15,6 +15,7 @@ use crate::{
|
|||
OneOrMany,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CompletionLayer<M> {
|
||||
model: M,
|
||||
preamble: Option<String>,
|
||||
|
@ -155,6 +156,7 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CompletionLayerService<M, S> {
|
||||
inner: S,
|
||||
model: M,
|
||||
|
|
|
@ -7,6 +7,7 @@ use crate::{
|
|||
message::{AssistantContent, Text},
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AwaitApprovalLayer<T> {
|
||||
integration: T,
|
||||
}
|
||||
|
@ -43,6 +44,7 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AwaitApprovalLayerService<S, T> {
|
||||
inner: S,
|
||||
integration: T,
|
||||
|
|
|
@ -9,6 +9,7 @@ use crate::{
|
|||
|
||||
use super::ServiceError;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ExtractorLayer<T> {
|
||||
_t: PhantomData<T>,
|
||||
}
|
||||
|
@ -41,6 +42,7 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ExtractorLayerService<S, T> {
|
||||
inner: S,
|
||||
_t: PhantomData<T>,
|
||||
|
|
|
@ -8,6 +8,7 @@ use std::{future::Future, pin::Pin, sync::Arc, task::Poll};
|
|||
|
||||
use tower::{Layer, Service};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ToolLayer {
|
||||
tools: Arc<ToolSet>,
|
||||
}
|
||||
|
@ -31,6 +32,7 @@ impl<S> Layer<S> for ToolLayer {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ToolLayerService<S> {
|
||||
inner: S,
|
||||
tools: Arc<ToolSet>,
|
||||
|
|
Loading…
Reference in New Issue