mirror of https://github.com/0xplaygrounds/rig
fix: satisfy ci
This commit is contained in:
parent
0ba10f0e3e
commit
bf652f47e8
|
@ -1,11 +1,14 @@
|
||||||
use rig::{
|
use rig::{
|
||||||
completion::CompletionRequestBuilder,
|
completion::{CompletionRequestBuilder, ToolDefinition},
|
||||||
middlewares::{
|
middlewares::{
|
||||||
completion::{CompletionLayer, CompletionService},
|
completion::{CompletionLayer, CompletionService},
|
||||||
tools::ToolLayer,
|
tools::ToolLayer,
|
||||||
},
|
},
|
||||||
providers::openai::Client,
|
providers::openai::Client,
|
||||||
|
tool::{Tool, ToolSet},
|
||||||
};
|
};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tower::Service;
|
||||||
use tower::ServiceBuilder;
|
use tower::ServiceBuilder;
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
|
@ -13,11 +16,11 @@ async fn main() {
|
||||||
let client = Client::from_env();
|
let client = Client::from_env();
|
||||||
let model = client.completion_model("gpt-4o");
|
let model = client.completion_model("gpt-4o");
|
||||||
|
|
||||||
let comp_layer = CompletionLayer::builder(model).build();
|
let comp_layer = CompletionLayer::builder(model.clone()).build();
|
||||||
let tool_layer = ToolLayer::new(vec![Add]);
|
let tool_layer = ToolLayer::new(ToolSet::from_tools(vec![Add]));
|
||||||
let service = CompletionService::new(model);
|
let service = CompletionService::new(model.clone());
|
||||||
|
|
||||||
let service = ServiceBuilder::new()
|
let mut service = ServiceBuilder::new()
|
||||||
.layer(comp_layer)
|
.layer(comp_layer)
|
||||||
.layer(tool_layer)
|
.layer(tool_layer)
|
||||||
.service(service);
|
.service(service);
|
||||||
|
@ -32,6 +35,16 @@ async fn main() {
|
||||||
#[derive(Deserialize, Serialize)]
|
#[derive(Deserialize, Serialize)]
|
||||||
struct Add;
|
struct Add;
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
#[error("Math error")]
|
||||||
|
struct MathError;
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct OperationArgs {
|
||||||
|
x: i32,
|
||||||
|
y: i32,
|
||||||
|
}
|
||||||
|
|
||||||
impl Tool for Add {
|
impl Tool for Add {
|
||||||
const NAME: &'static str = "add";
|
const NAME: &'static str = "add";
|
||||||
|
|
||||||
|
@ -40,7 +53,7 @@ impl Tool for Add {
|
||||||
type Output = i32;
|
type Output = i32;
|
||||||
|
|
||||||
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||||
serde_json::from_value(json!({
|
serde_json::from_value(serde_json::json!({
|
||||||
"name": "add",
|
"name": "add",
|
||||||
"description": "Add x and y together",
|
"description": "Add x and y together",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
|
|
|
@ -15,6 +15,7 @@ use crate::{
|
||||||
OneOrMany,
|
OneOrMany,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct CompletionLayer<M> {
|
pub struct CompletionLayer<M> {
|
||||||
model: M,
|
model: M,
|
||||||
preamble: Option<String>,
|
preamble: Option<String>,
|
||||||
|
@ -155,6 +156,7 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct CompletionLayerService<M, S> {
|
pub struct CompletionLayerService<M, S> {
|
||||||
inner: S,
|
inner: S,
|
||||||
model: M,
|
model: M,
|
||||||
|
|
|
@ -7,6 +7,7 @@ use crate::{
|
||||||
message::{AssistantContent, Text},
|
message::{AssistantContent, Text},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct AwaitApprovalLayer<T> {
|
pub struct AwaitApprovalLayer<T> {
|
||||||
integration: T,
|
integration: T,
|
||||||
}
|
}
|
||||||
|
@ -43,6 +44,7 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct AwaitApprovalLayerService<S, T> {
|
pub struct AwaitApprovalLayerService<S, T> {
|
||||||
inner: S,
|
inner: S,
|
||||||
integration: T,
|
integration: T,
|
||||||
|
|
|
@ -9,6 +9,7 @@ use crate::{
|
||||||
|
|
||||||
use super::ServiceError;
|
use super::ServiceError;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct ExtractorLayer<T> {
|
pub struct ExtractorLayer<T> {
|
||||||
_t: PhantomData<T>,
|
_t: PhantomData<T>,
|
||||||
}
|
}
|
||||||
|
@ -41,6 +42,7 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct ExtractorLayerService<S, T> {
|
pub struct ExtractorLayerService<S, T> {
|
||||||
inner: S,
|
inner: S,
|
||||||
_t: PhantomData<T>,
|
_t: PhantomData<T>,
|
||||||
|
|
|
@ -8,6 +8,7 @@ use std::{future::Future, pin::Pin, sync::Arc, task::Poll};
|
||||||
|
|
||||||
use tower::{Layer, Service};
|
use tower::{Layer, Service};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct ToolLayer {
|
pub struct ToolLayer {
|
||||||
tools: Arc<ToolSet>,
|
tools: Arc<ToolSet>,
|
||||||
}
|
}
|
||||||
|
@ -31,6 +32,7 @@ impl<S> Layer<S> for ToolLayer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct ToolLayerService<S> {
|
pub struct ToolLayerService<S> {
|
||||||
inner: S,
|
inner: S,
|
||||||
tools: Arc<ToolSet>,
|
tools: Arc<ToolSet>,
|
||||||
|
|
Loading…
Reference in New Issue