mirror of https://github.com/0xplaygrounds/rig
feat: a bunch of stuff
This commit is contained in:
parent
ed0baca252
commit
49221ab848
|
@ -0,0 +1,67 @@
|
||||||
|
use rig::{
|
||||||
|
completion::CompletionRequestBuilder,
|
||||||
|
middlewares::{
|
||||||
|
completion::{CompletionLayer, CompletionService},
|
||||||
|
tools::ToolLayer,
|
||||||
|
},
|
||||||
|
providers::openai::Client,
|
||||||
|
};
|
||||||
|
use tower::ServiceBuilder;
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
let client = Client::from_env();
|
||||||
|
let model = client.completion_model("gpt-4o");
|
||||||
|
|
||||||
|
let comp_layer = CompletionLayer::builder(model).build();
|
||||||
|
let tool_layer = ToolLayer::new(vec![Add]);
|
||||||
|
let service = CompletionService::new(model);
|
||||||
|
|
||||||
|
let service = ServiceBuilder::new()
|
||||||
|
.layer(comp_layer)
|
||||||
|
.layer(tool_layer)
|
||||||
|
.service(service);
|
||||||
|
|
||||||
|
let comp_request = CompletionRequestBuilder::new(model, "Please calculate 5+5 for me").build();
|
||||||
|
|
||||||
|
let res = service.call(comp_request).await.unwrap();
|
||||||
|
|
||||||
|
println!("{res:?}");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Serialize)]
|
||||||
|
struct Add;
|
||||||
|
|
||||||
|
impl Tool for Add {
|
||||||
|
const NAME: &'static str = "add";
|
||||||
|
|
||||||
|
type Error = MathError;
|
||||||
|
type Args = OperationArgs;
|
||||||
|
type Output = i32;
|
||||||
|
|
||||||
|
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||||
|
serde_json::from_value(json!({
|
||||||
|
"name": "add",
|
||||||
|
"description": "Add x and y together",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"x": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The first number to add"
|
||||||
|
},
|
||||||
|
"y": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The second number to add"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.expect("Tool Definition")
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||||
|
let result = args.x + args.y;
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
|
@ -235,6 +235,7 @@ pub trait CompletionModel: Clone + Send + Sync {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Struct representing a general completion request that can be sent to a completion model provider.
|
/// Struct representing a general completion request that can be sent to a completion model provider.
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct CompletionRequest {
|
pub struct CompletionRequest {
|
||||||
/// The prompt to be sent to the completion model provider
|
/// The prompt to be sent to the completion model provider
|
||||||
pub prompt: Message,
|
pub prompt: Message,
|
||||||
|
|
|
@ -0,0 +1,532 @@
|
||||||
|
use std::{future::Future, pin::Pin, task::Poll};
|
||||||
|
use tower::{Layer, Service};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
completion::{
|
||||||
|
CompletionModel, CompletionRequest, CompletionRequestBuilder, Document, ToolDefinition,
|
||||||
|
},
|
||||||
|
message::Message,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// A Tower layer to finish building your `CompletionRequestBuilder`.
|
||||||
|
/// Intended to be used with [`CompletionRequestBuilderService`].
|
||||||
|
///
|
||||||
|
/// See [`CompletionRequestBuilderService`] for usage.
|
||||||
|
pub struct FinishBuilding;
|
||||||
|
|
||||||
|
impl<S> Layer<S> for FinishBuilding {
|
||||||
|
type Service = FinishBuildingService<S>;
|
||||||
|
|
||||||
|
fn layer(&self, inner: S) -> Self::Service {
|
||||||
|
FinishBuildingService { inner }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Tower layer to finish building your `CompletionRequestBuilder`.
|
||||||
|
/// Not intended to be used directly. Use [`FinishBuilding`] instead.
|
||||||
|
///
|
||||||
|
/// See [`CompletionRequestBuilderService`] for usage.
|
||||||
|
pub struct FinishBuildingService<S> {
|
||||||
|
inner: S,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<M, Msg, S> Service<(M, Msg)> for FinishBuildingService<S>
|
||||||
|
where
|
||||||
|
M: CompletionModel + Send + 'static,
|
||||||
|
Msg: Into<Message> + Send + 'static,
|
||||||
|
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
|
||||||
|
S::Future: Send + 'static,
|
||||||
|
{
|
||||||
|
type Response = CompletionRequest;
|
||||||
|
type Error = ();
|
||||||
|
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||||
|
|
||||||
|
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call(&mut self, req: (M, Msg)) -> Self::Future {
|
||||||
|
let mut inner = self.inner.clone();
|
||||||
|
Box::pin(async move {
|
||||||
|
let Ok(res) = inner.call(req).await else {
|
||||||
|
todo!("Handle error properly");
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(res.build())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Tower layer to add documents to your `CompletionRequestBuilder`.
|
||||||
|
/// Intended to be used with [`CompletionRequestBuilderService`].
|
||||||
|
///
|
||||||
|
/// See [`CompletionRequestBuilderService`] for usage.
|
||||||
|
pub struct DocumentsLayer {
|
||||||
|
documents: Vec<Document>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DocumentsLayer {
|
||||||
|
pub fn new(documents: Vec<Document>) -> Self {
|
||||||
|
Self { documents }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> Layer<S> for DocumentsLayer {
|
||||||
|
type Service = DocumentsLayerService<S>;
|
||||||
|
|
||||||
|
fn layer(&self, inner: S) -> Self::Service {
|
||||||
|
DocumentsLayerService {
|
||||||
|
inner,
|
||||||
|
documents: self.documents.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Tower service to add documents to your `CompletionRequestBuilder`.
|
||||||
|
/// Not intended to be used directly - use [`DocumentsLayer`] instead.
|
||||||
|
///
|
||||||
|
/// See [`CompletionRequestBuilderService`] for usage.
|
||||||
|
pub struct DocumentsLayerService<S> {
|
||||||
|
inner: S,
|
||||||
|
documents: Vec<Document>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<M, Msg, S> Service<(M, Msg)> for DocumentsLayerService<S>
|
||||||
|
where
|
||||||
|
M: CompletionModel + Send + 'static,
|
||||||
|
Msg: Into<Message> + Send + 'static,
|
||||||
|
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
|
||||||
|
S::Future: Send + 'static,
|
||||||
|
{
|
||||||
|
type Response = CompletionRequestBuilder<M>;
|
||||||
|
type Error = ();
|
||||||
|
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||||
|
|
||||||
|
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call(&mut self, req: (M, Msg)) -> Self::Future {
|
||||||
|
let mut inner = self.inner.clone();
|
||||||
|
let documents = self.documents.clone();
|
||||||
|
Box::pin(async move {
|
||||||
|
let Ok(res) = inner.call(req).await else {
|
||||||
|
todo!("Handle error properly");
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(res.documents(documents))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Tower layer to add a temperature value to your `CompletionRequestBuilder`.
|
||||||
|
/// Intended to be used with [`CompletionRequestBuilderService`].
|
||||||
|
///
|
||||||
|
/// See [`CompletionRequestBuilderService`] for usage.
|
||||||
|
pub struct TemperatureLayer {
|
||||||
|
temperature: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TemperatureLayer {
|
||||||
|
pub fn new(temperature: f64) -> Self {
|
||||||
|
Self { temperature }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> Layer<S> for TemperatureLayer {
|
||||||
|
type Service = TemperatureLayerService<S>;
|
||||||
|
|
||||||
|
fn layer(&self, inner: S) -> Self::Service {
|
||||||
|
TemperatureLayerService {
|
||||||
|
inner,
|
||||||
|
temperature: self.temperature,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Tower service to add a temperature value to your `CompletionRequestBuilder`.
|
||||||
|
/// Not intended to be used directly - use [`TemperatureLayer`] instead.
|
||||||
|
///
|
||||||
|
/// See [`CompletionRequestBuilderService`] for usage.
|
||||||
|
pub struct TemperatureLayerService<S> {
|
||||||
|
inner: S,
|
||||||
|
temperature: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<M, Msg, S> Service<(M, Msg)> for TemperatureLayerService<S>
|
||||||
|
where
|
||||||
|
M: CompletionModel + Send + 'static,
|
||||||
|
Msg: Into<Message> + Send + 'static,
|
||||||
|
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
|
||||||
|
S::Future: Send + 'static,
|
||||||
|
{
|
||||||
|
type Response = CompletionRequestBuilder<M>;
|
||||||
|
type Error = ();
|
||||||
|
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||||
|
|
||||||
|
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call(&mut self, req: (M, Msg)) -> Self::Future {
|
||||||
|
let mut inner = self.inner.clone();
|
||||||
|
let temperature = self.temperature;
|
||||||
|
Box::pin(async move {
|
||||||
|
let Ok(res) = inner.call(req).await else {
|
||||||
|
todo!("Handle error properly");
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(res.temperature(temperature))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Tower service to add tools to your `CompletionRequestBuilder`.
|
||||||
|
/// Intended to be used with [`CompletionRequestBuilderService`].
|
||||||
|
///
|
||||||
|
/// See [`CompletionRequestBuilderService`] for usage.
|
||||||
|
pub struct ToolsLayer {
|
||||||
|
tools: Vec<ToolDefinition>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolsLayer {
|
||||||
|
pub fn new(tools: Vec<ToolDefinition>) -> Self {
|
||||||
|
Self { tools }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> Layer<S> for ToolsLayer {
|
||||||
|
type Service = ToolsLayerService<S>;
|
||||||
|
|
||||||
|
fn layer(&self, inner: S) -> Self::Service {
|
||||||
|
ToolsLayerService {
|
||||||
|
inner,
|
||||||
|
tools: self.tools.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Tower service to add tools to your `CompletionRequestBuilder`.
|
||||||
|
/// Not intended to be used directly - use [`ToolsLayer`] instead.
|
||||||
|
///
|
||||||
|
/// See [`CompletionRequestBuilderService`] for usage.
|
||||||
|
pub struct ToolsLayerService<S> {
|
||||||
|
inner: S,
|
||||||
|
tools: Vec<ToolDefinition>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<M, Msg, S> Service<(M, Msg)> for ToolsLayerService<S>
|
||||||
|
where
|
||||||
|
M: CompletionModel + Send + 'static,
|
||||||
|
Msg: Into<Message> + Send + 'static,
|
||||||
|
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
|
||||||
|
S::Future: Send + 'static,
|
||||||
|
{
|
||||||
|
type Response = CompletionRequestBuilder<M>;
|
||||||
|
type Error = ();
|
||||||
|
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||||
|
|
||||||
|
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call(&mut self, req: (M, Msg)) -> Self::Future {
|
||||||
|
let mut inner = self.inner.clone();
|
||||||
|
let tools = self.tools.clone();
|
||||||
|
Box::pin(async move {
|
||||||
|
let Ok(res) = inner.call(req).await else {
|
||||||
|
todo!("Handle error properly");
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(res.tools(tools))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Tower layer to add a preamble ("system message") to your `CompletionRequestBuilder`.
|
||||||
|
/// Intended to be used with [`CompletionRequestBuilderService`].
|
||||||
|
///
|
||||||
|
/// See [`CompletionRequestBuilderService`] for usage.
|
||||||
|
pub struct PreambleLayer {
|
||||||
|
preamble: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PreambleLayer {
|
||||||
|
pub fn new(preamble: String) -> Self {
|
||||||
|
Self { preamble }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> Layer<S> for PreambleLayer {
|
||||||
|
type Service = PreambleLayerService<S>;
|
||||||
|
|
||||||
|
fn layer(&self, inner: S) -> Self::Service {
|
||||||
|
PreambleLayerService {
|
||||||
|
inner,
|
||||||
|
preamble: self.preamble.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Tower service to add a preamble ("system message") to your `CompletionRequestBuilder`.
|
||||||
|
/// Not intended to be used directly - use [`PreambleLayer`] instead.
|
||||||
|
///
|
||||||
|
/// See [`CompletionRequestBuilderService`] for usage.
|
||||||
|
pub struct PreambleLayerService<S> {
|
||||||
|
inner: S,
|
||||||
|
preamble: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<M, Msg, S> Service<(M, Msg)> for PreambleLayerService<S>
|
||||||
|
where
|
||||||
|
M: CompletionModel + Send + 'static,
|
||||||
|
Msg: Into<Message> + Send + 'static,
|
||||||
|
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
|
||||||
|
S::Future: Send + 'static,
|
||||||
|
{
|
||||||
|
type Response = CompletionRequestBuilder<M>;
|
||||||
|
type Error = ();
|
||||||
|
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||||
|
|
||||||
|
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call(&mut self, req: (M, Msg)) -> Self::Future {
|
||||||
|
let mut inner = self.inner.clone();
|
||||||
|
let preamble = self.preamble.clone();
|
||||||
|
Box::pin(async move {
|
||||||
|
let Ok(res) = inner.call(req).await else {
|
||||||
|
todo!("Handle error properly");
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(res.preamble(preamble))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Tower layer to add additional parameters to your `CompletionRequestBuilder` (which are not already covered by the other parameters).
|
||||||
|
/// Intended to be used with [`CompletionRequestBuilderService`].
|
||||||
|
///
|
||||||
|
/// See [`CompletionRequestBuilderService`] for usage.
|
||||||
|
pub struct AdditionalParamsLayer {
|
||||||
|
additional_params: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AdditionalParamsLayer {
|
||||||
|
pub fn new(additional_params: serde_json::Value) -> Self {
|
||||||
|
Self { additional_params }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> Layer<S> for AdditionalParamsLayer {
|
||||||
|
type Service = AdditionalParamsLayerService<S>;
|
||||||
|
|
||||||
|
fn layer(&self, inner: S) -> Self::Service {
|
||||||
|
AdditionalParamsLayerService {
|
||||||
|
inner,
|
||||||
|
additional_params: self.additional_params.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Tower layer to add additional parameters to your `CompletionRequestBuilder` (which are not already covered by the other parameters).
|
||||||
|
/// Not intended to be used directly - use [`AdditionalParamsLayer`] instead.
|
||||||
|
///
|
||||||
|
/// See [`CompletionRequestBuilderService`] for usage.
|
||||||
|
pub struct AdditionalParamsLayerService<S> {
|
||||||
|
inner: S,
|
||||||
|
additional_params: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<M, Msg, S> Service<(M, Msg)> for AdditionalParamsLayerService<S>
|
||||||
|
where
|
||||||
|
M: CompletionModel + Send + 'static,
|
||||||
|
Msg: Into<Message> + Send + 'static,
|
||||||
|
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
|
||||||
|
S::Future: Send + 'static,
|
||||||
|
{
|
||||||
|
type Response = CompletionRequestBuilder<M>;
|
||||||
|
type Error = ();
|
||||||
|
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||||
|
|
||||||
|
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call(&mut self, req: (M, Msg)) -> Self::Future {
|
||||||
|
let mut inner = self.inner.clone();
|
||||||
|
let additional_params = self.additional_params.clone();
|
||||||
|
Box::pin(async move {
|
||||||
|
let Ok(res) = inner.call(req).await else {
|
||||||
|
todo!("Handle error properly");
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(res.additional_params(additional_params))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Tower layer to add a maximum tokens parameter to your `CompletionRequestBuilder`.
|
||||||
|
/// Intended to be used with [`CompletionRequestBuilderService`].
|
||||||
|
///
|
||||||
|
/// See [`CompletionRequestBuilderService`] for usage.
|
||||||
|
pub struct MaxTokensLayer {
|
||||||
|
max_tokens: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MaxTokensLayer {
|
||||||
|
pub fn new(max_tokens: u64) -> Self {
|
||||||
|
Self { max_tokens }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> Layer<S> for MaxTokensLayer {
|
||||||
|
type Service = MaxTokensLayerService<S>;
|
||||||
|
|
||||||
|
fn layer(&self, inner: S) -> Self::Service {
|
||||||
|
MaxTokensLayerService {
|
||||||
|
inner,
|
||||||
|
max_tokens: self.max_tokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Tower service to add a maximum tokens parameter to your `CompletionRequestBuilder`.
|
||||||
|
/// Not intended to be used directly - use [`MaxTokensLayer`] instead.
|
||||||
|
///
|
||||||
|
/// See [`CompletionRequestBuilderService`] for usage.
|
||||||
|
pub struct MaxTokensLayerService<S> {
|
||||||
|
inner: S,
|
||||||
|
max_tokens: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<M, Msg, S> Service<(M, Msg)> for MaxTokensLayerService<S>
|
||||||
|
where
|
||||||
|
M: CompletionModel + Send + 'static,
|
||||||
|
Msg: Into<Message> + Send + 'static,
|
||||||
|
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
|
||||||
|
S::Future: Send + 'static,
|
||||||
|
{
|
||||||
|
type Response = CompletionRequestBuilder<M>;
|
||||||
|
type Error = ();
|
||||||
|
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||||
|
|
||||||
|
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call(&mut self, req: (M, Msg)) -> Self::Future {
|
||||||
|
let mut inner = self.inner.clone();
|
||||||
|
let max_tokens = self.max_tokens;
|
||||||
|
Box::pin(async move {
|
||||||
|
let Ok(res) = inner.call(req).await else {
|
||||||
|
todo!("Handle error properly");
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(res.max_tokens(max_tokens))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Tower layer to add a chat history to your `CompletionRequestBuilder`.
|
||||||
|
/// Intended to be used with [`CompletionRequestBuilderService`].
|
||||||
|
///
|
||||||
|
/// See [`CompletionRequestBuilderService`] for usage.
|
||||||
|
pub struct ChatHistoryLayer {
|
||||||
|
chat_history: Vec<Message>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChatHistoryLayer {
|
||||||
|
pub fn new(chat_history: Vec<Message>) -> Self {
|
||||||
|
Self { chat_history }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> Layer<S> for ChatHistoryLayer {
|
||||||
|
type Service = ChatHistoryLayerService<S>;
|
||||||
|
|
||||||
|
fn layer(&self, inner: S) -> Self::Service {
|
||||||
|
ChatHistoryLayerService {
|
||||||
|
inner,
|
||||||
|
chat_history: self.chat_history.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Tower service to add a chat history to your `CompletionRequestBuilder`.
|
||||||
|
/// Not intended to be used directly - use [`ChatHistoryLayer`] instead.
|
||||||
|
///
|
||||||
|
/// See [`CompletionRequestBuilderService`] for usage.
|
||||||
|
pub struct ChatHistoryLayerService<S> {
|
||||||
|
inner: S,
|
||||||
|
chat_history: Vec<Message>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<M, Msg, S> Service<(M, Msg)> for ChatHistoryLayerService<S>
|
||||||
|
where
|
||||||
|
M: CompletionModel + Send + 'static,
|
||||||
|
Msg: Into<Message> + Send + 'static,
|
||||||
|
S: Service<(M, Msg), Response = CompletionRequestBuilder<M>> + Clone + Send + 'static,
|
||||||
|
S::Future: Send + 'static,
|
||||||
|
{
|
||||||
|
type Response = CompletionRequestBuilder<M>;
|
||||||
|
type Error = ();
|
||||||
|
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||||
|
|
||||||
|
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call(&mut self, req: (M, Msg)) -> Self::Future {
|
||||||
|
let mut inner = self.inner.clone();
|
||||||
|
let chat_history = self.chat_history.clone();
|
||||||
|
Box::pin(async move {
|
||||||
|
let Ok(res) = inner.call(req).await else {
|
||||||
|
todo!("Handle error properly");
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(res.messages(chat_history))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A `tower` service for building a completion request.
|
||||||
|
/// Note that the last layer to be applied goes first and the core service is placed last (so it gets executedd from bottom to top).
|
||||||
|
///
|
||||||
|
/// Usage:
|
||||||
|
/// ```rust
|
||||||
|
/// let agent = rig::providers::openai::Client::from_env();
|
||||||
|
/// let model = agent.completion_model("gpt-4o");
|
||||||
|
///
|
||||||
|
/// let service = tower::ServiceBuilder::new()
|
||||||
|
/// .layer(FinishBuilding)
|
||||||
|
/// .layer(TemperatureLayer::new(0.0))
|
||||||
|
/// .layer(PreambleLayer::new("You are a helpful assistant"))
|
||||||
|
/// .service(CompletionRequestBuilderService);
|
||||||
|
///
|
||||||
|
/// let request = service.call((model, "Hello world!".to_string())).await.unwrap();
|
||||||
|
///
|
||||||
|
/// let res = request.send().await.unwrap();
|
||||||
|
///
|
||||||
|
/// println!("{res:?}");
|
||||||
|
/// ```
|
||||||
|
pub struct CompletionRequestBuilderService;
|
||||||
|
|
||||||
|
impl<M, Msg> Service<(M, Msg)> for CompletionRequestBuilderService
|
||||||
|
where
|
||||||
|
M: CompletionModel + Send + 'static,
|
||||||
|
Msg: Into<Message> + Send + 'static,
|
||||||
|
{
|
||||||
|
type Response = CompletionRequestBuilder<M>;
|
||||||
|
type Error = ();
|
||||||
|
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||||
|
|
||||||
|
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call(&mut self, (model, prompt): (M, Msg)) -> Self::Future {
|
||||||
|
Box::pin(async move { Ok(CompletionRequestBuilder::new(model, prompt)) })
|
||||||
|
}
|
||||||
|
}
|
|
@ -224,6 +224,7 @@ where
|
||||||
/// A completion model as a Tower service.
|
/// A completion model as a Tower service.
|
||||||
///
|
///
|
||||||
/// This allows you to use an LLM model (or client) essentially anywhere you'd use a regular Tower layer, like in an Axum web service.
|
/// This allows you to use an LLM model (or client) essentially anywhere you'd use a regular Tower layer, like in an Axum web service.
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct CompletionService<M> {
|
pub struct CompletionService<M> {
|
||||||
/// The model itself.
|
/// The model itself.
|
||||||
model: M,
|
model: M,
|
||||||
|
|
|
@ -1,68 +1,54 @@
|
||||||
use std::{future::Future, pin::Pin, sync::Arc, task::Poll};
|
use serde::Deserialize;
|
||||||
|
use std::{future::Future, marker::PhantomData, pin::Pin, task::Poll};
|
||||||
use schemars::JsonSchema;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use tower::{Layer, Service};
|
use tower::{Layer, Service};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
completion::{CompletionModel, CompletionRequest, CompletionResponse},
|
completion::{CompletionRequest, CompletionResponse},
|
||||||
extractor::{ExtractionError, Extractor},
|
|
||||||
message::{AssistantContent, Text},
|
message::{AssistantContent, Text},
|
||||||
};
|
};
|
||||||
|
|
||||||
pub struct ExtractorLayer<M, T>
|
use super::ServiceError;
|
||||||
where
|
|
||||||
M: CompletionModel,
|
pub struct ExtractorLayer<T> {
|
||||||
T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync,
|
_t: PhantomData<T>,
|
||||||
{
|
|
||||||
ext: Arc<Extractor<M, T>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<M, T> ExtractorLayer<M, T>
|
impl<T> ExtractorLayer<T>
|
||||||
where
|
where
|
||||||
M: CompletionModel,
|
T: for<'a> Deserialize<'a>,
|
||||||
T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync,
|
|
||||||
{
|
{
|
||||||
pub fn new(ext: Extractor<M, T>) -> Self {
|
pub fn new() -> Self {
|
||||||
Self { ext: Arc::new(ext) }
|
Self { _t: PhantomData }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S, M, T> Layer<S> for ExtractorLayer<M, T>
|
impl<S, T> Layer<S> for ExtractorLayer<T>
|
||||||
where
|
where
|
||||||
M: CompletionModel + 'static,
|
T: for<'a> Deserialize<'a>,
|
||||||
T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static,
|
|
||||||
{
|
{
|
||||||
type Service = ExtractorLayerService<S, M, T>;
|
type Service = ExtractorLayerService<S, T>;
|
||||||
fn layer(&self, inner: S) -> Self::Service {
|
fn layer(&self, inner: S) -> Self::Service {
|
||||||
ExtractorLayerService {
|
ExtractorLayerService { inner, _t: self._t }
|
||||||
inner,
|
|
||||||
ext: Arc::clone(&self.ext),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ExtractorLayerService<S, M, T>
|
pub struct ExtractorLayerService<S, T> {
|
||||||
where
|
|
||||||
M: CompletionModel + 'static,
|
|
||||||
T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static,
|
|
||||||
{
|
|
||||||
inner: S,
|
inner: S,
|
||||||
ext: Arc<Extractor<M, T>>,
|
_t: PhantomData<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S, M, T> Service<CompletionRequest> for ExtractorLayerService<S, M, T>
|
impl<S, F, T> Service<CompletionRequest> for ExtractorLayerService<S, T>
|
||||||
where
|
where
|
||||||
S: Service<CompletionRequest, Response = CompletionResponse<M::Response>>
|
S: Service<CompletionRequest, Response = CompletionResponse<F>, Error = ServiceError>
|
||||||
+ Clone
|
+ Clone
|
||||||
+ Send
|
+ Send
|
||||||
+ 'static,
|
+ 'static,
|
||||||
S::Future: Send,
|
S::Future: Send,
|
||||||
M: CompletionModel + 'static,
|
F: 'static,
|
||||||
T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync,
|
T: for<'a> Deserialize<'a> + 'static,
|
||||||
{
|
{
|
||||||
type Response = T;
|
type Response = T;
|
||||||
type Error = ExtractionError;
|
type Error = ServiceError;
|
||||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||||
|
|
||||||
fn poll_ready(
|
fn poll_ready(
|
||||||
|
@ -73,63 +59,18 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
fn call(&mut self, req: CompletionRequest) -> Self::Future {
|
fn call(&mut self, req: CompletionRequest) -> Self::Future {
|
||||||
let ext = self.ext.clone();
|
|
||||||
let mut inner = self.inner.clone();
|
let mut inner = self.inner.clone();
|
||||||
|
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
let Ok(res) = inner.call(req).await else {
|
let res = inner.call(req).await?;
|
||||||
todo!("Properly handle error");
|
|
||||||
};
|
|
||||||
|
|
||||||
let AssistantContent::Text(Text { text }) = res.choice.first() else {
|
let AssistantContent::Text(Text { text }) = res.choice.first() else {
|
||||||
todo!("Handle errors properly");
|
todo!("Handle errors properly");
|
||||||
};
|
};
|
||||||
|
|
||||||
ext.extract(&text).await
|
let obj = serde_json::from_str::<T>(&text)?;
|
||||||
|
|
||||||
|
Ok(obj)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ExtractorService<M, T>
|
|
||||||
where
|
|
||||||
M: CompletionModel + 'static,
|
|
||||||
T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static,
|
|
||||||
{
|
|
||||||
ext: Arc<Extractor<M, T>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<M, T> ExtractorService<M, T>
|
|
||||||
where
|
|
||||||
M: CompletionModel + 'static,
|
|
||||||
T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static,
|
|
||||||
{
|
|
||||||
pub fn new(ext: Extractor<M, T>) -> Self {
|
|
||||||
Self { ext: Arc::new(ext) }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<M, T> Service<CompletionRequest> for ExtractorService<M, T>
|
|
||||||
where
|
|
||||||
M: CompletionModel,
|
|
||||||
T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync,
|
|
||||||
{
|
|
||||||
type Response = T;
|
|
||||||
type Error = ExtractionError;
|
|
||||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
|
||||||
|
|
||||||
fn poll_ready(
|
|
||||||
&mut self,
|
|
||||||
_cx: &mut std::task::Context<'_>,
|
|
||||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
|
||||||
Poll::Ready(Ok(()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn call(&mut self, req: CompletionRequest) -> Self::Future {
|
|
||||||
let ext = self.ext.clone();
|
|
||||||
let Some(req) = req.prompt.rag_text() else {
|
|
||||||
todo!("Handle error properly");
|
|
||||||
};
|
|
||||||
|
|
||||||
Box::pin(async move { ext.extract(&req).await })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,10 +1,15 @@
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
use crate::{completion::CompletionError, extractor::ExtractionError, tool::ToolSetError};
|
use crate::{
|
||||||
|
completion::CompletionError, extractor::ExtractionError, tool::ToolSetError,
|
||||||
|
vector_store::VectorStoreError,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub mod build_completions;
|
||||||
pub mod completion;
|
pub mod completion;
|
||||||
pub mod components;
|
pub mod components;
|
||||||
pub mod extractor;
|
pub mod extractor;
|
||||||
|
pub mod parallel;
|
||||||
pub mod rag;
|
pub mod rag;
|
||||||
pub mod tools;
|
pub mod tools;
|
||||||
|
|
||||||
|
@ -16,4 +21,17 @@ pub enum ServiceError {
|
||||||
CompletionError(#[from] CompletionError),
|
CompletionError(#[from] CompletionError),
|
||||||
#[error("{0}")]
|
#[error("{0}")]
|
||||||
ToolSetError(#[from] ToolSetError),
|
ToolSetError(#[from] ToolSetError),
|
||||||
|
#[error("{0}")]
|
||||||
|
VectorStoreError(#[from] VectorStoreError),
|
||||||
|
#[error("Value required but was null: {0}")]
|
||||||
|
RequiredOptionNotFound(String),
|
||||||
|
#[error("{0}")]
|
||||||
|
Json(#[from] serde_json::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ServiceError {
|
||||||
|
pub fn required_option_not_exists<S: Into<String>>(val: S) -> Self {
|
||||||
|
let val: String = val.into();
|
||||||
|
Self::RequiredOptionNotFound(val)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,88 @@
|
||||||
|
use std::{future::Future, pin::Pin};
|
||||||
|
|
||||||
|
use tower::Service;
|
||||||
|
|
||||||
|
use crate::completion::CompletionRequest;
|
||||||
|
|
||||||
|
use super::ServiceError;
|
||||||
|
|
||||||
|
pub struct Stackable<A, B> {
|
||||||
|
pub inner: A,
|
||||||
|
pub outer: B,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A, B> Stackable<A, B> {
|
||||||
|
pub fn new(inner: A, outer: B) -> Self {
|
||||||
|
Self { inner, outer }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn take_values(self) -> (A, B) {
|
||||||
|
(self.inner, self.outer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct ParallelService<S, T> {
|
||||||
|
first_service: S,
|
||||||
|
second_service: T,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, T> ParallelService<S, T>
|
||||||
|
where
|
||||||
|
S: Service<CompletionRequest>,
|
||||||
|
T: Service<CompletionRequest>,
|
||||||
|
{
|
||||||
|
pub fn new(first_service: S, second_service: T) -> Self {
|
||||||
|
Self {
|
||||||
|
first_service,
|
||||||
|
second_service,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, T> Service<CompletionRequest> for ParallelService<S, T>
|
||||||
|
where
|
||||||
|
S: Service<CompletionRequest, Error = ServiceError> + Clone + Send + 'static,
|
||||||
|
S::Future: Send,
|
||||||
|
S::Response: Send + 'static,
|
||||||
|
T: Service<CompletionRequest, Error = ServiceError> + Clone + Send + 'static,
|
||||||
|
T::Future: Send,
|
||||||
|
T::Response: Send + 'static,
|
||||||
|
{
|
||||||
|
type Response = Stackable<S::Response, T::Response>;
|
||||||
|
type Error = ServiceError;
|
||||||
|
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||||
|
|
||||||
|
fn poll_ready(
|
||||||
|
&mut self,
|
||||||
|
_cx: &mut std::task::Context<'_>,
|
||||||
|
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||||
|
std::task::Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call(&mut self, req: CompletionRequest) -> Self::Future {
|
||||||
|
let mut first = self.first_service.clone();
|
||||||
|
let mut second = self.second_service.clone();
|
||||||
|
Box::pin(async move {
|
||||||
|
let res1 = first.call(req.clone()).await?;
|
||||||
|
let res2 = second.call(req.clone()).await?;
|
||||||
|
|
||||||
|
let stackable = Stackable::new(res1, res2);
|
||||||
|
|
||||||
|
Ok(stackable)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[macro_export]
|
||||||
|
macro_rules! parallel_service {
|
||||||
|
($service1:tt, $service2:tt) => {
|
||||||
|
$crate::pipeline::parallel::ParallelService::new($service1, $service2)
|
||||||
|
};
|
||||||
|
($op1:tt $(, $ops:tt)*) => {
|
||||||
|
$crate::pipeline::parallel::ParallelService::new(
|
||||||
|
$service1,
|
||||||
|
$crate::parallel_op!($($ops),*)
|
||||||
|
)
|
||||||
|
};
|
||||||
|
}
|
|
@ -9,10 +9,9 @@ use std::{
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tower::Service;
|
use tower::Service;
|
||||||
|
|
||||||
use crate::{
|
use crate::{completion::CompletionRequest, vector_store::VectorStoreIndex};
|
||||||
completion::CompletionRequest,
|
|
||||||
vector_store::{VectorStoreError, VectorStoreIndex},
|
use super::ServiceError;
|
||||||
};
|
|
||||||
|
|
||||||
pub struct RagService<V, T> {
|
pub struct RagService<V, T> {
|
||||||
vector_index: Arc<V>,
|
vector_index: Arc<V>,
|
||||||
|
@ -39,7 +38,7 @@ where
|
||||||
T: Serialize + for<'a> Deserialize<'a> + Send,
|
T: Serialize + for<'a> Deserialize<'a> + Send,
|
||||||
{
|
{
|
||||||
type Response = RagResult<T>;
|
type Response = RagResult<T>;
|
||||||
type Error = VectorStoreError;
|
type Error = ServiceError;
|
||||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||||
|
|
||||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
@ -49,11 +48,16 @@ where
|
||||||
fn call(&mut self, req: CompletionRequest) -> Self::Future {
|
fn call(&mut self, req: CompletionRequest) -> Self::Future {
|
||||||
let vector_index = self.vector_index.clone();
|
let vector_index = self.vector_index.clone();
|
||||||
let num_results = self.num_results;
|
let num_results = self.num_results;
|
||||||
let Some(prompt) = req.prompt.rag_text() else {
|
|
||||||
todo!("Handle error properly");
|
|
||||||
};
|
|
||||||
|
|
||||||
Box::pin(async move { vector_index.top_n(&prompt, num_results).await })
|
Box::pin(async move {
|
||||||
|
let Some(prompt) = req.prompt.rag_text() else {
|
||||||
|
return Err(ServiceError::required_option_not_exists("rag_text"));
|
||||||
|
};
|
||||||
|
|
||||||
|
let res = vector_index.top_n(&prompt, num_results).await?;
|
||||||
|
|
||||||
|
Ok(res)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
completion::{CompletionRequest, CompletionResponse},
|
completion::{CompletionRequest, CompletionResponse},
|
||||||
message::{AssistantContent, Message, ToolResultContent, UserContent},
|
message::{AssistantContent, Message, ToolResultContent},
|
||||||
providers::{self},
|
|
||||||
tool::{ToolSet, ToolSetError},
|
tool::{ToolSet, ToolSetError},
|
||||||
OneOrMany,
|
OneOrMany,
|
||||||
};
|
};
|
||||||
|
@ -83,49 +82,3 @@ where
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ToolService {
|
|
||||||
tools: Arc<ToolSet>,
|
|
||||||
}
|
|
||||||
|
|
||||||
type OpenAIResponse = CompletionResponse<providers::openai::CompletionResponse>;
|
|
||||||
|
|
||||||
impl Service<OpenAIResponse> for ToolService {
|
|
||||||
type Response = Message;
|
|
||||||
type Error = ToolSetError;
|
|
||||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
|
||||||
|
|
||||||
fn poll_ready(
|
|
||||||
&mut self,
|
|
||||||
_cx: &mut std::task::Context<'_>,
|
|
||||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
|
||||||
Poll::Ready(Ok(()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn call(&mut self, req: OpenAIResponse) -> Self::Future {
|
|
||||||
let tools = self.tools.clone();
|
|
||||||
|
|
||||||
Box::pin(async move {
|
|
||||||
let crate::message::AssistantContent::ToolCall(tool_call) = req.choice.first() else {
|
|
||||||
unimplemented!("handle error");
|
|
||||||
};
|
|
||||||
|
|
||||||
let Ok(res) = tools
|
|
||||||
.call(
|
|
||||||
&tool_call.function.name,
|
|
||||||
tool_call.function.arguments.to_string(),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
else {
|
|
||||||
todo!("Implement proper error handling");
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Message::User {
|
|
||||||
content: OneOrMany::one(UserContent::tool_result(
|
|
||||||
tool_call.id,
|
|
||||||
OneOrMany::one(ToolResultContent::text(res)),
|
|
||||||
)),
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in New Issue